blob: e168923048cb889cb567a6883e2dadac980071f1 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
David Beckf0b48452018-10-19 15:20:56 +01006#include <armnn/ArmNN.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01007
Aron Virginas-Tar64e4ccb2019-02-12 11:27:53 +00008#if defined(ARMNN_SERIALIZER)
Derek Lamberti0028d1b2019-02-20 13:57:42 +00009#include "armnnDeserializer/IDeserializer.hpp"
Aron Virginas-Tar64e4ccb2019-02-12 11:27:53 +000010#endif
telsoa01c577f2c2018-08-31 09:22:23 +010011#if defined(ARMNN_TF_LITE_PARSER)
David Beckf0b48452018-10-19 15:20:56 +010012#include <armnnTfLiteParser/ITfLiteParser.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010013#endif
telsoa01c577f2c2018-08-31 09:22:23 +010014#if defined(ARMNN_ONNX_PARSER)
David Beckf0b48452018-10-19 15:20:56 +010015#include <armnnOnnxParser/IOnnxParser.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010016#endif
telsoa014fcda012018-03-09 14:13:49 +000017
Aron Virginas-Tar64e4ccb2019-02-12 11:27:53 +000018#include <HeapProfiling.hpp>
19
David Beck1b61be52018-11-08 09:19:14 +000020#include <backendsCommon/BackendRegistry.hpp>
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010021
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000022#include <boost/algorithm/string/join.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010023#include <boost/exception/exception.hpp>
24#include <boost/exception/diagnostic_information.hpp>
telsoa014fcda012018-03-09 14:13:49 +000025#include <boost/log/trivial.hpp>
26#include <boost/format.hpp>
27#include <boost/program_options.hpp>
surmeh013537c2c2018-05-18 16:31:43 +010028#include <boost/filesystem.hpp>
David Beckf0b48452018-10-19 15:20:56 +010029#include <boost/lexical_cast.hpp>
Ferran Balaguerc602f292019-02-08 17:09:55 +000030#include <boost/variant.hpp>
telsoa014fcda012018-03-09 14:13:49 +000031
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000032#include <algorithm>
James Conroy7b4886f2019-04-11 10:23:58 +010033#include <chrono>
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000034#include <iterator>
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010035#include <fstream>
telsoa014fcda012018-03-09 14:13:49 +000036#include <map>
37#include <string>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000038#include <vector>
telsoa01c577f2c2018-08-31 09:22:23 +010039#include <type_traits>
40
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +010041namespace
42{
43
44inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
45 armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
46{
47 if (backendIds.empty())
48 {
49 return false;
50 }
51
52 armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
53
54 bool allValid = true;
55 for (const auto& backendId : backendIds)
56 {
57 if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
58 {
59 allValid = false;
60 if (invalidBackendIds)
61 {
62 if (!invalidBackendIds.value().empty())
63 {
64 invalidBackendIds.value() += ", ";
65 }
66 invalidBackendIds.value() += backendId;
67 }
68 }
69 }
70 return allValid;
71}
72
73} // anonymous namespace
74
telsoa01c577f2c2018-08-31 09:22:23 +010075namespace InferenceModelInternal
76{
77// This needs to go when the armnnCaffeParser, armnnTfParser and armnnTfLiteParser
78// definitions of BindingPointInfo gets consolidated.
79using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
80
81using QuantizationParams = std::pair<float,int32_t>;
82
83struct Params
84{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000085 std::string m_ModelPath;
86 std::vector<std::string> m_InputBindings;
87 std::vector<armnn::TensorShape> m_InputShapes;
88 std::vector<std::string> m_OutputBindings;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000089 std::vector<armnn::BackendId> m_ComputeDevices;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000090 bool m_EnableProfiling;
91 size_t m_SubgraphId;
92 bool m_IsModelBinary;
93 bool m_VisualizePostOptimizationModel;
94 bool m_EnableFp16TurboMode;
telsoa01c577f2c2018-08-31 09:22:23 +010095
96 Params()
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +000097 : m_ComputeDevices{"CpuRef"}
telsoa01c577f2c2018-08-31 09:22:23 +010098 , m_EnableProfiling(false)
99 , m_SubgraphId(0)
100 , m_IsModelBinary(true)
101 , m_VisualizePostOptimizationModel(false)
102 , m_EnableFp16TurboMode(false)
103 {}
104};
105
106} // namespace InferenceModelInternal
107
108template <typename IParser>
109struct CreateNetworkImpl
110{
111public:
112 using Params = InferenceModelInternal::Params;
113 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
114
115 static armnn::INetworkPtr Create(const Params& params,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116 std::vector<BindingPointInfo>& inputBindings,
117 std::vector<BindingPointInfo>& outputBindings)
telsoa01c577f2c2018-08-31 09:22:23 +0100118 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000119 const std::string& modelPath = params.m_ModelPath;
telsoa01c577f2c2018-08-31 09:22:23 +0100120
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000121 // Create a network from a file on disk
122 auto parser(IParser::Create());
telsoa01c577f2c2018-08-31 09:22:23 +0100123
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000124 std::map<std::string, armnn::TensorShape> inputShapes;
125 if (!params.m_InputShapes.empty())
126 {
127 const size_t numInputShapes = params.m_InputShapes.size();
128 const size_t numInputBindings = params.m_InputBindings.size();
129 if (numInputShapes < numInputBindings)
130 {
131 throw armnn::Exception(boost::str(boost::format(
132 "Not every input has its tensor shape specified: expected=%1%, got=%2%")
133 % numInputBindings % numInputShapes));
134 }
telsoa01c577f2c2018-08-31 09:22:23 +0100135
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000136 for (size_t i = 0; i < numInputShapes; i++)
137 {
138 inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
139 }
140 }
telsoa01c577f2c2018-08-31 09:22:23 +0100141
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000142 std::vector<std::string> requestedOutputs = params.m_OutputBindings;
143 armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
144
145 {
146 ARMNN_SCOPED_HEAP_PROFILING("Parsing");
147 // Handle text and binary input differently by calling the corresponding parser function
148 network = (params.m_IsModelBinary ?
149 parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) :
150 parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs));
151 }
152
153 for (const std::string& inputLayerName : params.m_InputBindings)
154 {
155 inputBindings.push_back(parser->GetNetworkInputBindingInfo(inputLayerName));
156 }
157
158 for (const std::string& outputLayerName : params.m_OutputBindings)
159 {
160 outputBindings.push_back(parser->GetNetworkOutputBindingInfo(outputLayerName));
161 }
162
163 return network;
telsoa01c577f2c2018-08-31 09:22:23 +0100164 }
165};
166
#if defined(ARMNN_SERIALIZER)
// Specialization for the ArmNN deserializer: loads a network previously saved with
// the ArmNN serializer from a binary file.
template <>
struct CreateNetworkImpl<armnnDeserializer::IDeserializer>
{
public:
    using IParser = armnnDeserializer::IDeserializer;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

    /// Deserializes the network in params.m_ModelPath and fills in the binding info
    /// for the requested input/output layer names.
    /// @throws armnn::FileNotFoundException if the model file does not exist.
    static armnn::INetworkPtr Create(const Params& params,
                                     std::vector<BindingPointInfo>& inputBindings,
                                     std::vector<BindingPointInfo>& outputBindings)
    {
        auto parser(IParser::Create());
        BOOST_ASSERT(parser);

        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");

            // Check the file exists up front so we can report a clear error (with
            // location) instead of handing an unreadable stream to the parser.
            boost::system::error_code errorCode;
            boost::filesystem::path pathToFile(params.m_ModelPath);
            if (!boost::filesystem::exists(pathToFile, errorCode))
            {
                throw armnn::FileNotFoundException(boost::str(
                    boost::format("Cannot find the file (%1%) errorCode: %2% %3%") %
                    params.m_ModelPath %
                    errorCode %
                    CHECK_LOCATION().AsString()));
            }
            std::ifstream file(params.m_ModelPath, std::ios::binary);

            network = parser->CreateNetworkFromBinary(file);
        }

        unsigned int subGraphId = boost::numeric_cast<unsigned int>(params.m_SubgraphId);

        for (const std::string& inputLayerName : params.m_InputBindings)
        {
            // The deserializer returns its own BindingPointInfo struct; convert it to
            // the (LayerBindingId, TensorInfo) pair used throughout this file.
            armnnDeserializer::BindingPointInfo inputBinding =
                parser->GetNetworkInputBindingInfo(subGraphId, inputLayerName);
            inputBindings.push_back(std::make_pair(inputBinding.m_BindingId, inputBinding.m_TensorInfo));
        }

        for (const std::string& outputLayerName : params.m_OutputBindings)
        {
            armnnDeserializer::BindingPointInfo outputBinding =
                parser->GetNetworkOutputBindingInfo(subGraphId, outputLayerName);
            outputBindings.push_back(std::make_pair(outputBinding.m_BindingId, outputBinding.m_TensorInfo));
        }

        return network;
    }
};
#endif
223
#if defined(ARMNN_TF_LITE_PARSER)
/// Specialization for the TensorFlow Lite parser: only binary (.tflite) models are
/// supported and binding lookups are scoped to the configured subgraph.
template <>
struct CreateNetworkImpl<armnnTfLiteParser::ITfLiteParser>
{
public:
    using IParser = armnnTfLiteParser::ITfLiteParser;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

    /// Parses the .tflite flatbuffer named by params.m_ModelPath and fills in the
    /// binding info for the requested input/output layer names.
    static armnn::INetworkPtr Create(const Params& params,
                                     std::vector<BindingPointInfo>& inputBindings,
                                     std::vector<BindingPointInfo>& outputBindings)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            network = parser->CreateNetworkFromBinaryFile(params.m_ModelPath.c_str());
        }

        for (const std::string& layerName : params.m_InputBindings)
        {
            inputBindings.push_back(
                parser->GetNetworkInputBindingInfo(params.m_SubgraphId, layerName));
        }

        for (const std::string& layerName : params.m_OutputBindings)
        {
            outputBindings.push_back(
                parser->GetNetworkOutputBindingInfo(params.m_SubgraphId, layerName));
        }

        return network;
    }
};
#endif
267
#if defined(ARMNN_ONNX_PARSER)
/// Specialization for the ONNX parser: supports both text and binary protobuf
/// models; binding lookups take only the layer name (no subgraph).
template <>
struct CreateNetworkImpl<armnnOnnxParser::IOnnxParser>
{
public:
    using IParser = armnnOnnxParser::IOnnxParser;
    using Params = InferenceModelInternal::Params;
    using BindingPointInfo = InferenceModelInternal::BindingPointInfo;

    /// Parses the ONNX model named by params.m_ModelPath and fills in the binding
    /// info for the requested input/output layer names.
    static armnn::INetworkPtr Create(const Params& params,
                                     std::vector<BindingPointInfo>& inputBindings,
                                     std::vector<BindingPointInfo>& outputBindings)
    {
        // Create a network from a file on disk
        auto parser(IParser::Create());

        armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};

        {
            ARMNN_SCOPED_HEAP_PROFILING("Parsing");
            // Dispatch to the binary or text parser entry point as configured.
            const char* path = params.m_ModelPath.c_str();
            network = params.m_IsModelBinary ? parser->CreateNetworkFromBinaryFile(path)
                                             : parser->CreateNetworkFromTextFile(path);
        }

        for (const std::string& layerName : params.m_InputBindings)
        {
            inputBindings.push_back(parser->GetNetworkInputBindingInfo(layerName));
        }

        for (const std::string& layerName : params.m_OutputBindings)
        {
            outputBindings.push_back(parser->GetNetworkOutputBindingInfo(layerName));
        }

        return network;
    }
};
#endif
telsoa014fcda012018-03-09 14:13:49 +0000311
312template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000313inline armnn::InputTensors MakeInputTensors(
314 const std::vector<InferenceModelInternal::BindingPointInfo>& inputBindings,
315 const std::vector<TContainer>& inputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000316{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000317 armnn::InputTensors inputTensors;
318
319 const size_t numInputs = inputBindings.size();
320 if (numInputs != inputDataContainers.size())
telsoa014fcda012018-03-09 14:13:49 +0000321 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000322 throw armnn::Exception(boost::str(boost::format("Number of inputs does not match number of "
323 "tensor data containers: %1% != %2%") % numInputs % inputDataContainers.size()));
telsoa014fcda012018-03-09 14:13:49 +0000324 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000325
326 for (size_t i = 0; i < numInputs; i++)
327 {
328 const InferenceModelInternal::BindingPointInfo& inputBinding = inputBindings[i];
329 const TContainer& inputData = inputDataContainers[i];
330
Ferran Balaguerc602f292019-02-08 17:09:55 +0000331 boost::apply_visitor([&](auto&& value)
332 {
333 if (value.size() != inputBinding.second.GetNumElements())
334 {
335 throw armnn::Exception("Input tensor has incorrect size");
336 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000337
Ferran Balaguerc602f292019-02-08 17:09:55 +0000338 armnn::ConstTensor inputTensor(inputBinding.second, value.data());
339 inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
340 },
341 inputData);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000342 }
343
344 return inputTensors;
telsoa014fcda012018-03-09 14:13:49 +0000345}
346
347template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000348inline armnn::OutputTensors MakeOutputTensors(
349 const std::vector<InferenceModelInternal::BindingPointInfo>& outputBindings,
350 std::vector<TContainer>& outputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000351{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000352 armnn::OutputTensors outputTensors;
353
354 const size_t numOutputs = outputBindings.size();
355 if (numOutputs != outputDataContainers.size())
telsoa014fcda012018-03-09 14:13:49 +0000356 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000357 throw armnn::Exception(boost::str(boost::format("Number of outputs does not match number of "
358 "tensor data containers: %1% != %2%") % numOutputs % outputDataContainers.size()));
telsoa014fcda012018-03-09 14:13:49 +0000359 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000360
361 for (size_t i = 0; i < numOutputs; i++)
362 {
363 const InferenceModelInternal::BindingPointInfo& outputBinding = outputBindings[i];
364 TContainer& outputData = outputDataContainers[i];
365
Ferran Balaguerc602f292019-02-08 17:09:55 +0000366 boost::apply_visitor([&](auto&& value)
367 {
368 if (value.size() != outputBinding.second.GetNumElements())
369 {
370 throw armnn::Exception("Output tensor has incorrect size");
371 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000372
Ferran Balaguerc602f292019-02-08 17:09:55 +0000373 armnn::Tensor outputTensor(outputBinding.second, value.data());
374 outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
375 },
376 outputData);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000377 }
378
379 return outputTensors;
telsoa014fcda012018-03-09 14:13:49 +0000380}
381
382template <typename IParser, typename TDataType>
383class InferenceModel
384{
385public:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000386 using DataType = TDataType;
387 using Params = InferenceModelInternal::Params;
388 using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
389 using QuantizationParams = InferenceModelInternal::QuantizationParams;
Ferran Balaguerc602f292019-02-08 17:09:55 +0000390 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa014fcda012018-03-09 14:13:49 +0000391
392 struct CommandLineOptions
393 {
394 std::string m_ModelDir;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000395 std::vector<std::string> m_ComputeDevices;
surmeh013537c2c2018-05-18 16:31:43 +0100396 bool m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100397 bool m_EnableFp16TurboMode;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000398
399 std::vector<armnn::BackendId> GetComputeDevicesAsBackendIds()
400 {
401 std::vector<armnn::BackendId> backendIds;
402 std::copy(m_ComputeDevices.begin(), m_ComputeDevices.end(), std::back_inserter(backendIds));
403 return backendIds;
404 }
telsoa014fcda012018-03-09 14:13:49 +0000405 };
406
407 static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options)
408 {
409 namespace po = boost::program_options;
410
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000411 const std::vector<std::string> defaultComputes = { "CpuAcc", "CpuRef" };
David Beckf0b48452018-10-19 15:20:56 +0100412
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100413 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
414 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
415
telsoa014fcda012018-03-09 14:13:49 +0000416 desc.add_options()
417 ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(),
telsoa01c577f2c2018-08-31 09:22:23 +0100418 "Path to directory containing model files (.caffemodel/.prototxt/.tflite)")
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000419 ("compute,c", po::value<std::vector<std::string>>(&options.m_ComputeDevices)->
420 default_value(defaultComputes, boost::algorithm::join(defaultComputes, ", "))->
421 multitoken(), backendsMessage.c_str())
surmeh013537c2c2018-05-18 16:31:43 +0100422 ("visualize-optimized-model,v",
423 po::value<bool>(&options.m_VisualizePostOptimizationModel)->default_value(false),
424 "Produce a dot file useful for visualizing the graph post optimization."
telsoa01c577f2c2018-08-31 09:22:23 +0100425 "The file will have the same name as the model with the .dot extention.")
426 ("fp16-turbo-mode", po::value<bool>(&options.m_EnableFp16TurboMode)->default_value(false),
427 "If this option is enabled FP32 layers, weights and biases will be converted "
428 "to FP16 where the backend supports it.");
telsoa014fcda012018-03-09 14:13:49 +0000429 }
430
telsoa01c577f2c2018-08-31 09:22:23 +0100431 InferenceModel(const Params& params, const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
432 : m_EnableProfiling(params.m_EnableProfiling)
telsoa014fcda012018-03-09 14:13:49 +0000433 {
telsoa01c577f2c2018-08-31 09:22:23 +0100434 if (runtime)
telsoa014fcda012018-03-09 14:13:49 +0000435 {
telsoa01c577f2c2018-08-31 09:22:23 +0100436 m_Runtime = runtime;
telsoa014fcda012018-03-09 14:13:49 +0000437 }
telsoa01c577f2c2018-08-31 09:22:23 +0100438 else
telsoa014fcda012018-03-09 14:13:49 +0000439 {
telsoa01c577f2c2018-08-31 09:22:23 +0100440 armnn::IRuntime::CreationOptions options;
Nina Drozd549ae372018-09-10 14:26:44 +0100441 options.m_EnableGpuProfiling = m_EnableProfiling;
telsoa01c577f2c2018-08-31 09:22:23 +0100442 m_Runtime = std::move(armnn::IRuntime::Create(options));
surmeh013537c2c2018-05-18 16:31:43 +0100443 }
telsoa014fcda012018-03-09 14:13:49 +0000444
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100445 std::string invalidBackends;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000446 if (!CheckRequestedBackendsAreValid(params.m_ComputeDevices, armnn::Optional<std::string&>(invalidBackends)))
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100447 {
448 throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends);
449 }
450
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000451 armnn::INetworkPtr network =
452 CreateNetworkImpl<IParser>::Create(params, m_InputBindings, m_OutputBindings);
telsoa014fcda012018-03-09 14:13:49 +0000453
surmeh013537c2c2018-05-18 16:31:43 +0100454 armnn::IOptimizedNetworkPtr optNet{nullptr, [](armnn::IOptimizedNetwork *){}};
455 {
456 ARMNN_SCOPED_HEAP_PROFILING("Optimizing");
telsoa01c577f2c2018-08-31 09:22:23 +0100457
458 armnn::OptimizerOptions options;
459 options.m_ReduceFp32ToFp16 = params.m_EnableFp16TurboMode;
460
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000461 optNet = armnn::Optimize(*network, params.m_ComputeDevices, m_Runtime->GetDeviceSpec(), options);
telsoa01c577f2c2018-08-31 09:22:23 +0100462 if (!optNet)
463 {
464 throw armnn::Exception("Optimize returned nullptr");
465 }
surmeh013537c2c2018-05-18 16:31:43 +0100466 }
telsoa014fcda012018-03-09 14:13:49 +0000467
surmeh013537c2c2018-05-18 16:31:43 +0100468 if (params.m_VisualizePostOptimizationModel)
469 {
470 boost::filesystem::path filename = params.m_ModelPath;
471 filename.replace_extension("dot");
472 std::fstream file(filename.c_str(),file.out);
473 optNet->SerializeToDot(file);
474 }
475
476 armnn::Status ret;
477 {
478 ARMNN_SCOPED_HEAP_PROFILING("LoadNetwork");
479 ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet));
480 }
481
telsoa014fcda012018-03-09 14:13:49 +0000482 if (ret == armnn::Status::Failure)
483 {
484 throw armnn::Exception("IRuntime::LoadNetwork failed");
485 }
486 }
487
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000488 void CheckInputIndexIsValid(unsigned int inputIndex) const
telsoa014fcda012018-03-09 14:13:49 +0000489 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000490 if (m_InputBindings.size() < inputIndex + 1)
491 {
492 throw armnn::Exception(boost::str(boost::format("Input index out of range: %1%") % inputIndex));
493 }
telsoa014fcda012018-03-09 14:13:49 +0000494 }
495
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000496 void CheckOutputIndexIsValid(unsigned int outputIndex) const
telsoa014fcda012018-03-09 14:13:49 +0000497 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000498 if (m_OutputBindings.size() < outputIndex + 1)
499 {
500 throw armnn::Exception(boost::str(boost::format("Output index out of range: %1%") % outputIndex));
501 }
502 }
503
504 unsigned int GetOutputSize(unsigned int outputIndex = 0u) const
505 {
506 CheckOutputIndexIsValid(outputIndex);
507 return m_OutputBindings[outputIndex].second.GetNumElements();
508 }
509
James Conroy7b4886f2019-04-11 10:23:58 +0100510 std::chrono::duration<double, std::milli> Run(
511 const std::vector<TContainer>& inputContainers,
512 std::vector<TContainer>& outputContainers)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000513 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000514 for (unsigned int i = 0; i < outputContainers.size(); ++i)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000515 {
516 const unsigned int expectedOutputDataSize = GetOutputSize(i);
Ferran Balaguerc602f292019-02-08 17:09:55 +0000517
518 boost::apply_visitor([expectedOutputDataSize, i](auto&& value)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000519 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000520 const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(value.size());
521 if (actualOutputDataSize < expectedOutputDataSize)
522 {
523 unsigned int outputIndex = boost::numeric_cast<unsigned int>(i);
524 throw armnn::Exception(
525 boost::str(boost::format("Not enough data for output #%1%: expected "
526 "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize));
527 }
528 },
529 outputContainers[i]);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000530 }
telsoa01c577f2c2018-08-31 09:22:23 +0100531
532 std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
533 if (profiler)
534 {
535 profiler->EnableProfiling(m_EnableProfiling);
536 }
537
James Conroy7b4886f2019-04-11 10:23:58 +0100538 // Start timer to record inference time in EnqueueWorkload (in milliseconds)
539 const auto start_time = GetCurrentTime();
540
telsoa014fcda012018-03-09 14:13:49 +0000541 armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000542 MakeInputTensors(inputContainers),
543 MakeOutputTensors(outputContainers));
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100544
James Conroy7b4886f2019-04-11 10:23:58 +0100545 const auto end_time = GetCurrentTime();
546
Sadik Armagan2b7a1582018-09-05 16:33:58 +0100547 // if profiling is enabled print out the results
548 if (profiler && profiler->IsProfilingEnabled())
549 {
550 profiler->Print(std::cout);
551 }
552
telsoa014fcda012018-03-09 14:13:49 +0000553 if (ret == armnn::Status::Failure)
554 {
555 throw armnn::Exception("IRuntime::EnqueueWorkload failed");
556 }
James Conroy7b4886f2019-04-11 10:23:58 +0100557 else
558 {
559 return std::chrono::duration<double, std::milli>(end_time - start_time);
560 }
telsoa014fcda012018-03-09 14:13:49 +0000561 }
562
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000563 const BindingPointInfo& GetInputBindingInfo(unsigned int inputIndex = 0u) const
telsoa01c577f2c2018-08-31 09:22:23 +0100564 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000565 CheckInputIndexIsValid(inputIndex);
566 return m_InputBindings[inputIndex];
telsoa01c577f2c2018-08-31 09:22:23 +0100567 }
568
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000569 const std::vector<BindingPointInfo>& GetInputBindingInfos() const
telsoa01c577f2c2018-08-31 09:22:23 +0100570 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000571 return m_InputBindings;
telsoa01c577f2c2018-08-31 09:22:23 +0100572 }
573
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000574 const BindingPointInfo& GetOutputBindingInfo(unsigned int outputIndex = 0u) const
telsoa01c577f2c2018-08-31 09:22:23 +0100575 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000576 CheckOutputIndexIsValid(outputIndex);
577 return m_OutputBindings[outputIndex];
578 }
579
580 const std::vector<BindingPointInfo>& GetOutputBindingInfos() const
581 {
582 return m_OutputBindings;
583 }
584
585 QuantizationParams GetQuantizationParams(unsigned int outputIndex = 0u) const
586 {
587 CheckOutputIndexIsValid(outputIndex);
588 return std::make_pair(m_OutputBindings[outputIndex].second.GetQuantizationScale(),
589 m_OutputBindings[outputIndex].second.GetQuantizationOffset());
590 }
591
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000592 QuantizationParams GetInputQuantizationParams(unsigned int inputIndex = 0u) const
593 {
594 CheckInputIndexIsValid(inputIndex);
595 return std::make_pair(m_InputBindings[inputIndex].second.GetQuantizationScale(),
596 m_InputBindings[inputIndex].second.GetQuantizationOffset());
597 }
598
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000599 std::vector<QuantizationParams> GetAllQuantizationParams() const
600 {
601 std::vector<QuantizationParams> quantizationParams;
602 for (unsigned int i = 0u; i < m_OutputBindings.size(); i++)
603 {
604 quantizationParams.push_back(GetQuantizationParams(i));
605 }
606 return quantizationParams;
telsoa01c577f2c2018-08-31 09:22:23 +0100607 }
608
telsoa014fcda012018-03-09 14:13:49 +0000609private:
telsoa01c577f2c2018-08-31 09:22:23 +0100610 armnn::NetworkId m_NetworkIdentifier;
611 std::shared_ptr<armnn::IRuntime> m_Runtime;
612
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000613 std::vector<InferenceModelInternal::BindingPointInfo> m_InputBindings;
614 std::vector<InferenceModelInternal::BindingPointInfo> m_OutputBindings;
telsoa01c577f2c2018-08-31 09:22:23 +0100615 bool m_EnableProfiling;
616
telsoa014fcda012018-03-09 14:13:49 +0000617 template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000618 armnn::InputTensors MakeInputTensors(const std::vector<TContainer>& inputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000619 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000620 return ::MakeInputTensors(m_InputBindings, inputDataContainers);
telsoa014fcda012018-03-09 14:13:49 +0000621 }
622
623 template<typename TContainer>
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000624 armnn::OutputTensors MakeOutputTensors(std::vector<TContainer>& outputDataContainers)
telsoa014fcda012018-03-09 14:13:49 +0000625 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000626 return ::MakeOutputTensors(m_OutputBindings, outputDataContainers);
telsoa014fcda012018-03-09 14:13:49 +0000627 }
James Conroy7b4886f2019-04-11 10:23:58 +0100628
629 std::chrono::high_resolution_clock::time_point GetCurrentTime()
630 {
631 return std::chrono::high_resolution_clock::now();
632 }
633
634 std::chrono::duration<double, std::milli> GetTimeDuration(
635 std::chrono::high_resolution_clock::time_point& start_time,
636 std::chrono::high_resolution_clock::time_point& end_time)
637 {
638 return std::chrono::duration<double, std::milli>(end_time - start_time);
639 }
640
Ferran Balaguerc602f292019-02-08 17:09:55 +0000641};