//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ArmnnNetworkExecutor.hpp"
#include "Types.hpp"

#include <random>
#include <string>

namespace od
{

armnn::DataType ArmnnNetworkExecutor::GetInputDataType() const
{
    return m_inputBindingInfo.second.GetDataType();
}

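// Builds the executor end-to-end: parses the TfLite model, binds its input
// and output tensors, optimizes the network for the preferred backends,
// loads it onto the runtime, and pre-allocates the output buffers.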
ArmnnNetworkExecutor::ArmnnNetworkExecutor(std::string& modelPath,
                                           std::vector<armnn::BackendId>& preferredBackends)
: m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions()))
{
    // Import the TensorFlow Lite model.
    armnnTfLiteParser::ITfLiteParserPtr parser = armnnTfLiteParser::ITfLiteParser::Create();
    armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile(modelPath.c_str());

    // Bind the first input tensor of the first subgraph.
    std::vector<std::string> inputNames = parser->GetSubgraphInputTensorNames(0);
    m_inputBindingInfo = parser->GetNetworkInputBindingInfo(0, inputNames[0]);

    // Bind every output tensor of the first subgraph.
    m_outputLayerNamesList = parser->GetSubgraphOutputTensorNames(0);
    for (const std::string& name : m_outputLayerNamesList)
    {
        m_outputBindingInfo.push_back(parser->GetNetworkOutputBindingInfo(0, name));
    }

    std::vector<std::string> errorMessages;

    // Optimize the network for the preferred backends.
    armnn::IOptimizedNetworkPtr optNet = Optimize(*network,
                                                  preferredBackends,
                                                  m_Runtime->GetDeviceSpec(),
                                                  armnn::OptimizerOptions(),
                                                  armnn::Optional<std::vector<std::string>&>(errorMessages));

    if (!optNet)
    {
        const std::string errorMessage{"ArmnnNetworkExecutor: Failed to optimize network"};
        ARMNN_LOG(error) << errorMessage;
        throw armnn::Exception(errorMessage);
    }

    // Load the optimized network onto the runtime device. Without a loaded
    // network m_NetId is invalid and every inference would fail, so treat a
    // load failure the same way as an optimization failure.
    std::string errorMessage;
    if (armnn::Status::Success != m_Runtime->LoadNetwork(m_NetId, std::move(optNet), errorMessage))
    {
        ARMNN_LOG(error) << errorMessage;
        throw armnn::Exception(errorMessage);
    }

    // Pre-allocate memory for the outputs (their sizes never change).
    for (size_t it = 0; it < m_outputLayerNamesList.size(); ++it)
    {
        const armnn::DataType dataType = m_outputBindingInfo[it].second.GetDataType();
        const armnn::TensorShape& tensorShape = m_outputBindingInfo[it].second.GetShape();

        InferenceResult oneLayerOutResult;
        switch (dataType)
        {
            case armnn::DataType::Float32:
            {
                oneLayerOutResult.resize(tensorShape.GetNumElements(), 0);
                break;
            }
            default:
            {
                errorMessage = "ArmnnNetworkExecutor: unsupported output tensor data type";
                ARMNN_LOG(error) << errorMessage << " " << log_as_int(dataType);
                throw armnn::Exception(errorMessage);
            }
        }

        m_OutputBuffer.emplace_back(oneLayerOutResult);
    }

    // Make ArmNN output tensors that wrap the pre-allocated buffers. This loop
    // runs once, after all buffers exist, so each output layer gets exactly one
    // tensor; nesting it in the loop above would append duplicate entries.
    m_OutputTensors.reserve(m_OutputBuffer.size());
    for (size_t it = 0; it < m_OutputBuffer.size(); ++it)
    {
        m_OutputTensors.emplace_back(std::make_pair(
                m_outputBindingInfo[it].first,
                armnn::Tensor(m_outputBindingInfo[it].second,
                              m_OutputBuffer.at(it).data())
        ));
    }
}

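// Wraps the caller's raw buffer in an armnn::ConstTensor bound to the network
// input. The data is not copied, so the buffer must remain valid while the
// inference runs.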
void ArmnnNetworkExecutor::PrepareTensors(const void* inputData, const size_t dataBytes)
{
    assert(m_inputBindingInfo.second.GetNumBytes() >= dataBytes);
    m_InputTensors.clear();
    m_InputTensors = {{ m_inputBindingInfo.first, armnn::ConstTensor(m_inputBindingInfo.second, inputData)}};
}

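// Runs a single inference: binds inputData as the network input, enqueues the
// workload, and copies the pre-allocated output buffers into outResults.
// Returns true on success.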
bool ArmnnNetworkExecutor::Run(const void* inputData, const size_t dataBytes, InferenceResults& outResults)
{
    // Wrap the caller's buffer in the input tensor for this inference.
    ARMNN_LOG(debug) << "Preparing tensors...";
    this->PrepareTensors(inputData, dataBytes);
    ARMNN_LOG(trace) << "Running inference...";

    armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetId, m_InputTensors, m_OutputTensors);

    std::stringstream inferenceFinished;
    inferenceFinished << "Inference finished with code {" << log_as_int(ret) << "}\n";
    ARMNN_LOG(trace) << inferenceFinished.str();

    if (ret == armnn::Status::Failure)
    {
        ARMNN_LOG(error) << "Failed to perform inference.";
    }

    // The runtime wrote results directly into m_OutputBuffer; copy them out.
    outResults = m_OutputBuffer;

    return (armnn::Status::Success == ret);
}

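// Returns the width and height the model expects, taken from the network's 4D
// NHWC input tensor shape; frames should be scaled to this size before Run().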
Size ArmnnNetworkExecutor::GetImageAspectRatio()
{
    const auto shape = m_inputBindingInfo.second.GetShape();

    // The model input is expected to be a 4D tensor in NHWC layout.
    assert(shape.GetNumDimensions() == 4);

    armnnUtils::DataLayoutIndexed nhwc(armnn::DataLayout::NHWC);
    return Size(shape[nhwc.GetWidthIndex()],
                shape[nhwc.GetHeightIndex()]);
}
} // namespace od
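
// A minimal usage sketch (illustrative only, not part of this translation
// unit). It assumes a "model.tflite" file exists, that Types.hpp provides
// od::InferenceResults, and that the caller owns a preprocessed NHWC float
// input buffer sized to match GetImageAspectRatio():
//
//     std::string modelPath = "model.tflite";
//     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc,
//                                               armnn::Compute::CpuRef};
//     od::ArmnnNetworkExecutor executor(modelPath, backends);
//
//     std::vector<float> input(/* width * height * channels */);
//     od::InferenceResults results;
//     if (executor.Run(input.data(), input.size() * sizeof(float), results))
//     {
//         // results[i] holds the flattened output of the i-th output layer.
//     }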