blob: c75b68bbe1e51454da3f4756b1536ca9e7a19120 [file] [log] [blame]
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "Types.hpp"
9
10#include "armnn/ArmNN.hpp"
11#include "armnnTfLiteParser/ITfLiteParser.hpp"
12#include "armnnUtils/DataLayoutIndexed.hpp"
13#include <armnn/Logging.hpp>
14
15#include <string>
16#include <vector>
17
18namespace od
19{
20/**
21* @brief Used to load in a network through ArmNN and run inference on it against a given backend.
22*
23*/
24class ArmnnNetworkExecutor
25{
26private:
27 armnn::IRuntimePtr m_Runtime;
28 armnn::NetworkId m_NetId{};
29 mutable InferenceResults m_OutputBuffer;
30 armnn::InputTensors m_InputTensors;
31 armnn::OutputTensors m_OutputTensors;
32 std::vector<armnnTfLiteParser::BindingPointInfo> m_outputBindingInfo;
33
34 std::vector<std::string> m_outputLayerNamesList;
35
36 armnnTfLiteParser::BindingPointInfo m_inputBindingInfo;
37
38 void PrepareTensors(const void* inputData, const size_t dataBytes);
39
40 template <typename Enumeration>
41 auto log_as_int(Enumeration value)
42 -> typename std::underlying_type<Enumeration>::type
43 {
44 return static_cast<typename std::underlying_type<Enumeration>::type>(value);
45 }
46
47public:
48 ArmnnNetworkExecutor() = delete;
49
50 /**
51 * @brief Initializes the network with the given input data. Parsed through TfLiteParser and optimized for a
52 * given backend.
53 *
54 * Note that the output layers names order in m_outputLayerNamesList affects the order of the feature vectors
55 * in output of the Run method.
56 *
57 * * @param[in] modelPath - Relative path to the model file
58 * * @param[in] backends - The list of preferred backends to run inference on
59 */
60 ArmnnNetworkExecutor(std::string& modelPath,
61 std::vector<armnn::BackendId>& backends);
62
63 /**
64 * @brief Returns the aspect ratio of the associated model in the order of width, height.
65 */
66 Size GetImageAspectRatio();
67
68 armnn::DataType GetInputDataType() const;
69
70 /**
71 * @brief Runs inference on the provided input data, and stores the results in the provided InferenceResults object.
72 *
73 * @param[in] inputData - input frame data
74 * @param[in] dataBytes - input data size in bytes
75 * @param[out] results - Vector of DetectionResult objects used to store the output result.
76 */
77 bool Run(const void* inputData, const size_t dataBytes, InferenceResults& outResults);
78
79};
80}// namespace od