blob: 7e338669c79dbc01fa8bcb41097ed81f12a898c2 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
David Beckf0b48452018-10-19 15:20:56 +01006#include <armnn/ArmNN.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01007
8#if defined(ARMNN_TF_LITE_PARSER)
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnnTfLiteParser/ITfLiteParser.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010010#endif
11
12#include <HeapProfiling.hpp>
13#if defined(ARMNN_ONNX_PARSER)
David Beckf0b48452018-10-19 15:20:56 +010014#include <armnnOnnxParser/IOnnxParser.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010015#endif
telsoa014fcda012018-03-09 14:13:49 +000016
David Beck1b61be52018-11-08 09:19:14 +000017#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010018
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000019#include <boost/algorithm/string/join.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010020#include <boost/exception/exception.hpp>
21#include <boost/exception/diagnostic_information.hpp>
telsoa014fcda012018-03-09 14:13:49 +000022#include <boost/log/trivial.hpp>
23#include <boost/format.hpp>
24#include <boost/program_options.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010025#include <boost/filesystem.hpp>
David Beckf0b48452018-10-19 15:20:56 +010026#include <boost/lexical_cast.hpp>
telsoa014fcda012018-03-09 14:13:49 +000027
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000028#include <algorithm>
29#include <iterator>
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010030#include <fstream>
telsoa014fcda012018-03-09 14:13:49 +000031#include <map>
32#include <string>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000033#include <vector>
telsoa01c577f2c2018-08-31 09:22:23 +010034#include <type_traits>
35
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010036namespace
37{
38
39inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
40 armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
41{
42 if (backendIds.empty())
43 {
44 return false;
45 }
46
47 armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
48
49 bool allValid = true;
50 for (const auto& backendId : backendIds)
51 {
52 if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
53 {
54 allValid = false;
55 if (invalidBackendIds)
56 {
57 if (!invalidBackendIds.value().empty())
58 {
59 invalidBackendIds.value() += ", ";
60 }
61 invalidBackendIds.value() += backendId;
62 }
63 }
64 }
65 return allValid;
66}
67
68} // anonymous namespace
69
telsoa01c577f2c2018-08-31 09:22:23 +010070namespace InferenceModelInternal
71{
72// This needs to go when the armnnCaffeParser, armnnTfParser and armnnTfLiteParser
73// definitions of BindingPointInfo gets consolidated.
74using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
75
76using QuantizationParams = std::pair<float,int32_t>;
77
78struct Params
79{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000080 std::string m_ModelPath;
81 std::vector<std::string> m_InputBindings;
82 std::vector<armnn::TensorShape> m_InputShapes;
83 std::vector<std::string> m_OutputBindings;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000084 std::vector<armnn::BackendId> m_ComputeDevices;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000085 bool m_EnableProfiling;
86 size_t m_SubgraphId;
87 bool m_IsModelBinary;
88 bool m_VisualizePostOptimizationModel;
89 bool m_EnableFp16TurboMode;
telsoa01c577f2c2018-08-31 09:22:23 +010090
91 Params()
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000092 : m_ComputeDevices{"CpuRef"}
telsoa01c577f2c2018-08-31 09:22:23 +010093 , m_EnableProfiling(false)
94 , m_SubgraphId(0)
95 , m_IsModelBinary(true)
96 , m_VisualizePostOptimizationModel(false)
97 , m_EnableFp16TurboMode(false)
98 {}
99};
100
101} // namespace InferenceModelInternal
102
103template <typename IParser>
104struct CreateNetworkImpl
105{
106public:
107 using Params = InferenceModelInternal::Params;
108 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
109
110 static armnn::INetworkPtr Create(const Params& params,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000111 std::vector<BindingPointInfo>& inputBindings,
112 std::vector<BindingPointInfo>& outputBindings)
telsoa01c577f2c2018-08-31 09:22:23 +0100113 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000114 const std::string& modelPath = params.m_ModelPath;
telsoa01c577f2c2018-08-31 09:22:23 +0100115
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116 // Create a network from a file on disk
117 auto parser(IParser::Create());
telsoa01c577f2c2018-08-31 09:22:23 +0100118
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000119 std::map<std::string, armnn::TensorShape> inputShapes;
120 if (!params.m_InputShapes.empty())
121 {
122 const size_t numInputShapes = params.m_InputShapes.size();
123 const size_t numInputBindings = params.m_InputBindings.size();
124 if (numInputShapes < numInputBindings)
125 {
126 throw armnn::Exception(boost::str(boost::format(
127 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
128 % numInputBindings % numInputShapes));
129 }
telsoa01c577f2c2018-08-31 09:22:23 +0100130
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000131 for (size_t i = 0; i < numInputShapes; i++)
132 {
133 inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
134 }
135 }
telsoa01c577f2c2018-08-31 09:22:23 +0100136
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000137 std::vector<std::string> requestedOutputs = params.m_OutputBindings;
138 armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
139
140 {
141 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
142 // Handle text and binary input differently by calling the corresponding parser function
143 network = (params.m_IsModelBinary ?
144 parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) :
145 parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs));
146 }
147
148 for (const std::string& inputLayerName : params.m_InputBindings)
149 {
150 inputBindings.push_back(parser->GetNetworkInputBindingInfo(inputLayerName));
151 }
152
153 for (const std::string& outputLayerName : params.m_OutputBindings)
154 {
155 outputBindings.push_back(parser->GetNetworkOutputBindingInfo(outputLayerName));
156 }
157
158 return network;
telsoa01c577f2c2018-08-31 09:22:23 +0100159 }
160};
161
#if defined(ARMNN_TF_LITE_PARSER)
// Specialization for the TfLite parser: models are always binary flatbuffers,
// and binding lookups are scoped to a subgraph id.
template <>
struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>
{
public:
    using IParser = armnnTfLiteParser::ITfLiteParser;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

    // Parses the flatbuffer at params.m_ModelPath and fills the binding vectors
    // from subgraph params.m_SubgraphId.
    static armnn::INetworkPtr Create(const Params& params,
                                     std::vector<BindingPointInfo>& inputBindings,
                                     std::vector<BindingPointInfo>& outputBindings)
    {
        // Instantiate the TfLite parser.
        auto parser(IParser::Create());

        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            // TfLite models are binary only; no text variant exists.
            network = parser->CreateNetworkFromBinaryFile(params.m_ModelPath.c_str());
        }

        for (const std::string& layerName : params.m_InputBindings)
        {
            inputBindings.push_back(parser->GetNetworkInputBindingInfo(params.m_SubgraphId, layerName));
        }

        for (const std::string& layerName : params.m_OutputBindings)
        {
            outputBindings.push_back(parser->GetNetworkOutputBindingInfo(params.m_SubgraphId, layerName));
        }

        return network;
    }
};
#endif
205
#if defined(ARMNN_ONNX_PARSER)
// Specialization for the ONNX parser: no shape overrides or requested-output
// lists are passed; bindings are looked up purely by layer name.
template <>
struct CreateNetworkImpl<armnnOnnxParser::IOnnxParser>
{
public:
    using IParser = armnnOnnxParser::IOnnxParser;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

    // Parses the ONNX model at params.m_ModelPath (binary or text form) and
    // fills the binding vectors for every requested layer name.
    static armnn::INetworkPtr Create(const Params& params,
                                     std::vector<BindingPointInfo>& inputBindings,
                                     std::vector<BindingPointInfo>& outputBindings)
    {
        // Instantiate the ONNX parser.
        auto parser(IParser::Create());

        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            // Dispatch on the model's serialization format.
            if (params.m_IsModelBinary)
            {
                network = parser->CreateNetworkFromBinaryFile(params.m_ModelPath.c_str());
            }
            else
            {
                network = parser->CreateNetworkFromTextFile(params.m_ModelPath.c_str());
            }
        }

        for (const std::string& layerName : params.m_InputBindings)
        {
            inputBindings.push_back(parser->GetNetworkInputBindingInfo(layerName));
        }

        for (const std::string& layerName : params.m_OutputBindings)
        {
            outputBindings.push_back(parser->GetNetworkOutputBindingInfo(layerName));
        }

        return network;
    }
};
#endif
telsoa014fcda012018-03-09 14:13:49 +0000249
250template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000251inline armnn::InputTensors MakeInputTensors(
252 const std::vector<InferenceModelInternal::BindingPointInfo>& inputBindings,
253 const std::vector<TContainer>& inputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000254{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000255 armnn::InputTensors inputTensors;
256
257 const size_t numInputs = inputBindings.size();
258 if (numInputs != inputDataContainers.size())
telsoa014fcda012018-03-09 14:13:49 +0000259 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000260 throw armnn::Exception(boost::str(boost::format("Number of inputs does not match number of "
261 "tensor data containers: %1% != %2%") % numInputs % inputDataContainers.size()));
telsoa014fcda012018-03-09 14:13:49 +0000262 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000263
264 for (size_t i = 0; i < numInputs; i++)
265 {
266 const InferenceModelInternal::BindingPointInfo& inputBinding = inputBindings[i];
267 const TContainer& inputData = inputDataContainers[i];
268
269 if (inputData.size() != inputBinding.second.GetNumElements())
270 {
271 throw armnn::Exception("Input tensor has incorrect size");
272 }
273
274 armnn::ConstTensor inputTensor(inputBinding.second, inputData.data());
275 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
276 }
277
278 return inputTensors;
telsoa014fcda012018-03-09 14:13:49 +0000279}
280
281template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000282inline armnn::OutputTensors MakeOutputTensors(
283 const std::vector<InferenceModelInternal::BindingPointInfo>& outputBindings,
284 std::vector<TContainer>& outputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000285{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000286 armnn::OutputTensors outputTensors;
287
288 const size_t numOutputs = outputBindings.size();
289 if (numOutputs != outputDataContainers.size())
telsoa014fcda012018-03-09 14:13:49 +0000290 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000291 throw armnn::Exception(boost::str(boost::format("Number of outputs does not match number of "
292 "tensor data containers: %1% != %2%") % numOutputs % outputDataContainers.size()));
telsoa014fcda012018-03-09 14:13:49 +0000293 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000294
295 for (size_t i = 0; i < numOutputs; i++)
296 {
297 const InferenceModelInternal::BindingPointInfo& outputBinding = outputBindings[i];
298 TContainer& outputData = outputDataContainers[i];
299
300 if (outputData.size() != outputBinding.second.GetNumElements())
301 {
302 throw armnn::Exception("Output tensor has incorrect size");
303 }
304
305 armnn::Tensor outputTensor(outputBinding.second, outputData.data());
306 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
307 }
308
309 return outputTensors;
telsoa014fcda012018-03-09 14:13:49 +0000310}
311
// End-to-end wrapper around an ArmNN network: parses a model with IParser,
// optimizes it, loads it into a runtime, and exposes a simple Run() API over
// std::vector<TDataType> containers.
template <typename IParser, typename TDataType>
class InferenceModel
{
public:
    using DataType = TDataType;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
    using QuantizationParams = InferenceModelInternal::QuantizationParams;
    using TContainer = std::vector<TDataType>;

    // Options populated from the command line via AddCommandLineOptions().
    struct CommandLineOptions
    {
        std::string m_ModelDir;                   // Directory containing the model files.
        std::vector<std::string> m_ComputeDevices; // Backend names as typed on the command line.
        bool m_VisualizePostOptimizationModel;    // Dump a .dot of the optimized graph.
        bool m_EnableFp16TurboMode;               // Reduce FP32 to FP16 where supported.

        // Converts the raw device-name strings into armnn::BackendId values.
        std::vector<armnn::BackendId> GetComputeDevicesAsBackendIds()
        {
            std::vector<armnn::BackendId> backendIds;
            std::copy(m_ComputeDevices.begin(), m_ComputeDevices.end(), std::back_inserter(backendIds));
            return backendIds;
        }
    };

    // Registers this model's command-line options on desc, binding the parsed
    // values into the given options struct.
    static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options)
    {
        namespace po = boost::program_options;

        const std::vector<std::string> defaultComputes = { "CpuAcc", "CpuRef" };

        const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
            + armnn::BackendRegistryInstance().GetBackendIdsAsString();

        desc.add_options()
            ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(),
                "Path to directory containing model files (.caffemodel/.prototxt/.tflite)")
            ("compute,c", po::value<std::vector<std::string>>(&options.m_ComputeDevices)->
                default_value(defaultComputes, boost::algorithm::join(defaultComputes, ", "))->
                multitoken(), backendsMessage.c_str())
            ("visualize-optimized-model,v",
                po::value<bool>(&options.m_VisualizePostOptimizationModel)->default_value(false),
             "Produce a dot file useful for visualizing the graph post optimization."
                "The file will have the same name as the model with the .dot extention.")
            ("fp16-turbo-mode", po::value<bool>(&options.m_EnableFp16TurboMode)->default_value(false),
                "If this option is enabled FP32 layers, weights and biases will be converted "
                "to FP16 where the backend supports it.");
    }

    // Parses, validates, optimizes and loads the network described by params.
    // If no runtime is supplied, a private one is created (with GPU profiling
    // matching params.m_EnableProfiling). Throws armnn::Exception on any
    // failure: invalid backend ids, Optimize returning null, or LoadNetwork
    // failing.
    InferenceModel(const Params& params, const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
        : m_EnableProfiling(params.m_EnableProfiling)
    {
        if (runtime)
        {
            // Share the caller's runtime (e.g. to load several models into one).
            m_Runtime = runtime;
        }
        else
        {
            armnn::IRuntime::CreationOptions options;
            options.m_EnableGpuProfiling = m_EnableProfiling;
            // NOTE(review): std::move on the returned temporary is redundant here.
            m_Runtime = std::move(armnn::IRuntime::Create(options));
        }

        // Reject unknown backend ids up front with a readable list of offenders.
        std::string invalidBackends;
        if (!CheckRequestedBackendsAreValid(params.m_ComputeDevices, armnn::Optional<std::string&>(invalidBackends)))
        {
            throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends);
        }

        // Parse the model and record the input/output binding points.
        armnn::INetworkPtr network =
            CreateNetworkImpl<IParser>::Create(params, m_InputBindings, m_OutputBindings);

        armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork *){}};
        {
            // Scoped so heap profiling attributes optimization allocations separately.
            ARMNN_SCOPED_HEAP_PROFILING("Optimizing");

            armnn::OptimizerOptions options;
            options.m_ReduceFp32ToFp16 = params.m_EnableFp16TurboMode;

            optNet = armnn::Optimize(*network, params.m_ComputeDevices, m_Runtime->GetDeviceSpec(), options);
            if (!optNet)
            {
                throw armnn::Exception("Optimize returned nullptr");
            }
        }

        if (params.m_VisualizePostOptimizationModel)
        {
            // Write <model>.dot next to the model file for graph visualization.
            boost::filesystem::path filename = params.m_ModelPath;
            filename.replace_extension("dot");
            std::fstream file(filename.c_str(),file.out);
            optNet->SerializeToDot(file);
        }

        armnn::Status ret;
        {
            ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork");
            // LoadNetwork consumes optNet and assigns m_NetworkIdentifier.
            ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet));
        }

        if (ret == armnn::Status::Failure)
        {
            throw armnn::Exception("IRuntime::LoadNetwork failed");
        }
    }

    // Throws armnn::Exception if inputIndex is not a valid input binding index.
    void CheckInputIndexIsValid(unsigned int inputIndex) const
    {
        if (m_InputBindings.size() < inputIndex + 1)
        {
            throw armnn::Exception(boost::str(boost::format("Input index out of range: %1%") % inputIndex));
        }
    }

    // Throws armnn::Exception if outputIndex is not a valid output binding index.
    void CheckOutputIndexIsValid(unsigned int outputIndex) const
    {
        if (m_OutputBindings.size() < outputIndex + 1)
        {
            throw armnn::Exception(boost::str(boost::format("Output index out of range: %1%") % outputIndex));
        }
    }

    // Returns the element count of the given output tensor (validates the index).
    unsigned int GetOutputSize(unsigned int outputIndex = 0u) const
    {
        CheckOutputIndexIsValid(outputIndex);
        return m_OutputBindings[outputIndex].second.GetNumElements();
    }

    // Runs one inference. outputContainers must already be sized to hold each
    // output; inputContainers must match the input bindings exactly (checked
    // by MakeInputTensors). Throws armnn::Exception on undersized outputs or
    // if EnqueueWorkload fails. Prints profiling results when enabled.
    void Run(const std::vector<TContainer>& inputContainers, std::vector<TContainer>& outputContainers)
    {
        // Validate output buffer sizes before doing any work.
        for (unsigned int i = 0; i < outputContainers.size(); i++)
        {
            const unsigned int expectedOutputDataSize = GetOutputSize(i);
            const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(outputContainers[i].size());
            if (actualOutputDataSize < expectedOutputDataSize)
            {
                unsigned int outputIndex = boost::numeric_cast<unsigned int>(i);
                throw armnn::Exception(boost::str(boost::format("Not enough data for output #%1%: expected "
                    "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize));
            }
        }

        // Sync the profiler's enabled state with this model's setting.
        std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
        if (profiler)
        {
            profiler->EnableProfiling(m_EnableProfiling);
        }

        armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
                                                       MakeInputTensors(inputContainers),
                                                       MakeOutputTensors(outputContainers));

        // if profiling is enabled print out the results
        if (profiler && profiler->IsProfilingEnabled())
        {
            profiler->Print(std::cout);
        }

        if (ret == armnn::Status::Failure)
        {
            throw armnn::Exception("IRuntime::EnqueueWorkload failed");
        }
    }

    // Returns the binding info for one input (validates the index).
    const BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
    {
        CheckInputIndexIsValid(inputIndex);
        return m_InputBindings[inputIndex];
    }

    // Returns all input binding infos in binding order.
    const std::vector<BindingPointInfo>& GetInputBindingInfos() const
    {
        return m_InputBindings;
    }

    // Returns the binding info for one output (validates the index).
    const BindingPointInfo& GetOutputBindingInfo(unsigned int outputIndex = 0u) const
    {
        CheckOutputIndexIsValid(outputIndex);
        return m_OutputBindings[outputIndex];
    }

    // Returns all output binding infos in binding order.
    const std::vector<BindingPointInfo>& GetOutputBindingInfos() const
    {
        return m_OutputBindings;
    }

    // Returns the (scale, offset) quantization parameters of one output.
    QuantizationParams GetQuantizationParams(unsigned int outputIndex = 0u) const
    {
        CheckOutputIndexIsValid(outputIndex);
        return std::make_pair(m_OutputBindings[outputIndex].second.GetQuantizationScale(),
                              m_OutputBindings[outputIndex].second.GetQuantizationOffset());
    }

    // Returns quantization parameters for every output, in binding order.
    std::vector<QuantizationParams> GetAllQuantizationParams() const
    {
        std::vector<QuantizationParams> quantizationParams;
        for (unsigned int i = 0u; i < m_OutputBindings.size(); i++)
        {
            quantizationParams.push_back(GetQuantizationParams(i));
        }
        return quantizationParams;
    }

private:
    armnn::NetworkId m_NetworkIdentifier;            // Id assigned by LoadNetwork in the ctor.
    std::shared_ptr<armnn::IRuntime> m_Runtime;      // Shared or privately-created runtime.

    std::vector<InferenceModelInternal::BindingPointInfo> m_InputBindings;  // Filled by the parser.
    std::vector<InferenceModelInternal::BindingPointInfo> m_OutputBindings; // Filled by the parser.
    bool m_EnableProfiling;                          // Cached from Params at construction.

    // Binds input containers to this model's input binding points.
    template<typename TContainer>
    armnn::InputTensors MakeInputTensors(const std::vector<TContainer>& inputDataContainers)
    {
        return ::MakeInputTensors(m_InputBindings, inputDataContainers);
    }

    // Binds output containers to this model's output binding points.
    template<typename TContainer>
    armnn::OutputTensors MakeOutputTensors(std::vector<TContainer>& outputDataContainers)
    {
        return ::MakeOutputTensors(m_OutputBindings, outputDataContainers);
    }
};