blob: 92aa5066c0d98a4505cf87827a1788045e63d027 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include <armnn/ArmNN.hpp>
6#include <armnn/TypesUtils.hpp>
7
8#if defined(ARMNN_SERIALIZER)
9#include "armnnDeserializer/IDeserializer.hpp"
10#endif
11#if defined(ARMNN_CAFFE_PARSER)
12#include "armnnCaffeParser/ICaffeParser.hpp"
13#endif
14#if defined(ARMNN_TF_PARSER)
15#include "armnnTfParser/ITfParser.hpp"
16#endif
17#if defined(ARMNN_TF_LITE_PARSER)
18#include "armnnTfLiteParser/ITfLiteParser.hpp"
19#endif
20#if defined(ARMNN_ONNX_PARSER)
21#include "armnnOnnxParser/IOnnxParser.hpp"
22#endif
23#include "CsvReader.hpp"
24#include "../InferenceTest.hpp"
25
26#include <Logging.hpp>
27#include <Profiling.hpp>
28
29#include <boost/algorithm/string/trim.hpp>
30#include <boost/algorithm/string/split.hpp>
31#include <boost/algorithm/string/classification.hpp>
32#include <boost/program_options.hpp>
33#include <boost/variant.hpp>
34
35#include <iostream>
36#include <fstream>
37#include <functional>
38#include <future>
39#include <algorithm>
40#include <iterator>
41
42namespace
43{
44
45// Configure boost::program_options for command-line parsing and validation.
46namespace po = boost::program_options;
47
48template<typename T, typename TParseElementFunc>
49std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc, const char * chars = "\t ,:")
50{
51 std::vector<T> result;
52 // Processes line-by-line.
53 std::string line;
54 while (std::getline(stream, line))
55 {
56 std::vector<std::string> tokens;
57 try
58 {
59 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
60 boost::split(tokens, line, boost::algorithm::is_any_of(chars), boost::token_compress_on);
61 }
62 catch (const std::exception& e)
63 {
64 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
65 continue;
66 }
67 for (const std::string& token : tokens)
68 {
69 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
70 {
71 try
72 {
73 result.push_back(parseElementFunc(token));
74 }
75 catch (const std::exception&)
76 {
77 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
78 }
79 }
80 }
81 }
82
83 return result;
84}
85
86bool CheckOption(const po::variables_map& vm,
87 const char* option)
88{
89 // Check that the given option is valid.
90 if (option == nullptr)
91 {
92 return false;
93 }
94
95 // Check whether 'option' is provided.
96 return vm.find(option) != vm.end();
97}
98
99void CheckOptionDependency(const po::variables_map& vm,
100 const char* option,
101 const char* required)
102{
103 // Check that the given options are valid.
104 if (option == nullptr || required == nullptr)
105 {
106 throw po::error("Invalid option to check dependency for");
107 }
108
109 // Check that if 'option' is provided, 'required' is also provided.
110 if (CheckOption(vm, option) && !vm[option].defaulted())
111 {
112 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
113 {
114 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
115 }
116 }
117}
118
119void CheckOptionDependencies(const po::variables_map& vm)
120{
121 CheckOptionDependency(vm, "model-path", "model-format");
122 CheckOptionDependency(vm, "model-path", "input-name");
123 CheckOptionDependency(vm, "model-path", "input-tensor-data");
124 CheckOptionDependency(vm, "model-path", "output-name");
125 CheckOptionDependency(vm, "input-tensor-shape", "model-path");
126}
127
// Reads tensor data from 'stream' in the element type matching NonQuantizedType;
// specialized below for Float32, Signed32 and (pass-through) QuantisedAsymm8.
template<armnn::DataType NonQuantizedType>
auto ParseDataArray(std::istream & stream);

// Reads float tensor data from 'stream' and quantizes each value with the given
// scale/offset; specialized below for QuantisedAsymm8.
template<armnn::DataType QuantizedType>
auto ParseDataArray(std::istream& stream,
                    const float& quantizationScale,
                    const int32_t& quantizationOffset);
135
136template<>
137auto ParseDataArray<armnn::DataType::Float32>(std::istream & stream)
138{
139 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
140}
141
142template<>
143auto ParseDataArray<armnn::DataType::Signed32>(std::istream & stream)
144{
145 return ParseArrayImpl<int>(stream, [](const std::string & s) { return std::stoi(s); });
146}
147
148template<>
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100149auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream)
150{
151 return ParseArrayImpl<uint8_t>(stream,
152 [](const std::string& s) { return boost::numeric_cast<uint8_t>(std::stoi(s)); });
153}
154
155template<>
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100156auto ParseDataArray<armnn::DataType::QuantisedAsymm8>(std::istream& stream,
157 const float& quantizationScale,
158 const int32_t& quantizationOffset)
159{
160 return ParseArrayImpl<uint8_t>(stream,
161 [&quantizationScale, &quantizationOffset](const std::string & s)
162 {
163 return boost::numeric_cast<uint8_t>(
164 armnn::Quantize<u_int8_t>(std::stof(s),
165 quantizationScale,
166 quantizationOffset));
167 });
168}
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100169std::vector<unsigned int> ParseArray(std::istream& stream)
170{
171 return ParseArrayImpl<unsigned int>(stream,
172 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
173}
174
175std::vector<std::string> ParseStringList(const std::string & inputString, const char * delimiter)
176{
177 std::stringstream stream(inputString);
178 return ParseArrayImpl<std::string>(stream, [](const std::string& s) { return boost::trim_copy(s); }, delimiter);
179}
180
181void RemoveDuplicateDevices(std::vector<armnn::BackendId>& computeDevices)
182{
183 // Mark the duplicate devices as 'Undefined'.
184 for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
185 {
186 for (auto j = std::next(i); j != computeDevices.end(); ++j)
187 {
188 if (*j == *i)
189 {
190 *j = armnn::Compute::Undefined;
191 }
192 }
193 }
194
195 // Remove 'Undefined' devices.
196 computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
197 computeDevices.end());
198}
199
200struct TensorPrinter : public boost::static_visitor<>
201{
202 TensorPrinter(const std::string& binding, const armnn::TensorInfo& info)
203 : m_OutputBinding(binding)
204 , m_Scale(info.GetQuantizationScale())
205 , m_Offset(info.GetQuantizationOffset())
206 {}
207
208 void operator()(const std::vector<float>& values)
209 {
210 ForEachValue(values, [](float value){
211 printf("%f ", value);
212 });
213 }
214
215 void operator()(const std::vector<uint8_t>& values)
216 {
217 auto& scale = m_Scale;
218 auto& offset = m_Offset;
219 ForEachValue(values, [&scale, &offset](uint8_t value)
220 {
221 printf("%f ", armnn::Dequantize(value, scale, offset));
222 });
223 }
224
225 void operator()(const std::vector<int>& values)
226 {
227 ForEachValue(values, [](int value)
228 {
229 printf("%d ", value);
230 });
231 }
232
233private:
234 template<typename Container, typename Delegate>
235 void ForEachValue(const Container& c, Delegate delegate)
236 {
237 std::cout << m_OutputBinding << ": ";
238 for (const auto& value : c)
239 {
240 delegate(value);
241 }
242 printf("\n");
243 }
244
245 std::string m_OutputBinding;
246 float m_Scale=0.0f;
247 int m_Offset=0;
248};
249
250
251} // namespace
252
// Loads the model at 'modelPath' with parser TParser, feeds it the tensor data read
// from 'inputTensorDataFilePaths', runs one inference, prints every output tensor,
// and logs the inference time.
//
// Returns EXIT_SUCCESS, or EXIT_FAILURE on an unsupported input/output type, an
// armnn::Exception, or when the inference time exceeds a non-zero 'thresholdTime'
// (milliseconds). 'runtime' may be shared across calls (e.g. from RunCsvTest);
// when null, InferenceModel creates its own.
template<typename TParser, typename TDataType>
int MainImpl(const char* modelPath,
             bool isModelBinary,
             const std::vector<armnn::BackendId>& computeDevices,
             const std::string& dynamicBackendsPath,
             const std::vector<string>& inputNames,
             const std::vector<std::unique_ptr<armnn::TensorShape>>& inputTensorShapes,
             const std::vector<string>& inputTensorDataFilePaths,
             const std::vector<string>& inputTypes,
             bool quantizeInput,
             const std::vector<string>& outputTypes,
             const std::vector<string>& outputNames,
             bool enableProfiling,
             bool enableFp16TurboMode,
             const double& thresholdTime,
             bool printIntermediate,
             const size_t subgraphId,
             const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    // One variant per tensor; the alternative chosen must match TensorPrinter's.
    using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;

    std::vector<TContainer> inputDataContainers;

    try
    {
        // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
        typename InferenceModel<TParser, TDataType>::Params params;
        params.m_ModelPath = modelPath;
        params.m_IsModelBinary = isModelBinary;
        params.m_ComputeDevices = computeDevices;
        params.m_DynamicBackendsPath = dynamicBackendsPath;
        params.m_PrintIntermediateLayers = printIntermediate;

        for(const std::string& inputName: inputNames)
        {
            params.m_InputBindings.push_back(inputName);
        }

        for(unsigned int i = 0; i < inputTensorShapes.size(); ++i)
        {
            params.m_InputShapes.push_back(*inputTensorShapes[i]);
        }

        for(const std::string& outputName: outputNames)
        {
            params.m_OutputBindings.push_back(outputName);
        }

        params.m_SubgraphId = subgraphId;
        params.m_EnableFp16TurboMode = enableFp16TurboMode;
        InferenceModel<TParser, TDataType> model(params, enableProfiling, dynamicBackendsPath, runtime);

        // Read each input tensor's data from its file, converting per the declared type.
        for(unsigned int i = 0; i < inputTensorDataFilePaths.size(); ++i)
        {
            std::ifstream inputTensorFile(inputTensorDataFilePaths[i]);

            if (inputTypes[i].compare("float") == 0)
            {
                if (quantizeInput)
                {
                    // NOTE(review): GetInputBindingInfo() takes no index here, so for a
                    // multi-input network every input is quantized with the same
                    // binding's scale/offset — confirm this is intended.
                    auto inputBinding = model.GetInputBindingInfo();
                    inputDataContainers.push_back(
                        ParseDataArray<armnn::DataType::QuantisedAsymm8>(inputTensorFile,
                                                                         inputBinding.second.GetQuantizationScale(),
                                                                         inputBinding.second.GetQuantizationOffset()));
                }
                else
                {
                    inputDataContainers.push_back(
                        ParseDataArray<armnn::DataType::Float32>(inputTensorFile));
                }
            }
            else if (inputTypes[i].compare("int") == 0)
            {
                inputDataContainers.push_back(
                    ParseDataArray<armnn::DataType::Signed32>(inputTensorFile));
            }
            else if (inputTypes[i].compare("qasymm8") == 0)
            {
                // Data is already quantized on disk; read raw uint8 values.
                inputDataContainers.push_back(
                    ParseDataArray<armnn::DataType::QuantisedAsymm8>(inputTensorFile));
            }
            else
            {
                BOOST_LOG_TRIVIAL(fatal) << "Unsupported tensor data type \"" << inputTypes[i] << "\". ";
                return EXIT_FAILURE;
            }

            inputTensorFile.close();
        }

        // Pre-size one output container per binding, typed per 'outputTypes'.
        const size_t numOutputs = params.m_OutputBindings.size();
        std::vector<TContainer> outputDataContainers;

        for (unsigned int i = 0; i < numOutputs; ++i)
        {
            if (outputTypes[i].compare("float") == 0)
            {
                outputDataContainers.push_back(std::vector<float>(model.GetOutputSize(i)));
            }
            else if (outputTypes[i].compare("int") == 0)
            {
                outputDataContainers.push_back(std::vector<int>(model.GetOutputSize(i)));
            }
            else if (outputTypes[i].compare("qasymm8") == 0)
            {
                outputDataContainers.push_back(std::vector<uint8_t>(model.GetOutputSize(i)));
            }
            else
            {
                BOOST_LOG_TRIVIAL(fatal) << "Unsupported tensor data type \"" << outputTypes[i] << "\". ";
                return EXIT_FAILURE;
            }
        }

        // model.Run returns the inference time elapsed in EnqueueWorkload (in milliseconds)
        auto inference_duration = model.Run(inputDataContainers, outputDataContainers);

        // Print output tensors
        const auto& infosOut = model.GetOutputBindingInfos();
        for (size_t i = 0; i < numOutputs; i++)
        {
            const armnn::TensorInfo& infoOut = infosOut[i].second;
            TensorPrinter printer(params.m_OutputBindings[i], infoOut);
            boost::apply_visitor(printer, outputDataContainers[i]);
        }

        BOOST_LOG_TRIVIAL(info) << "\nInference time: " << std::setprecision(2)
                                << std::fixed << inference_duration.count() << " ms";

        // If thresholdTime == 0.0 (default), then it hasn't been supplied at command line
        if (thresholdTime != 0.0)
        {
            BOOST_LOG_TRIVIAL(info) << "Threshold time: " << std::setprecision(2)
                                    << std::fixed << thresholdTime << " ms";
            auto thresholdMinusInference = thresholdTime - inference_duration.count();
            BOOST_LOG_TRIVIAL(info) << "Threshold time - Inference time: " << std::setprecision(2)
                                    << std::fixed << thresholdMinusInference << " ms" << "\n";

            // A negative margin means the run was slower than the caller allows.
            if (thresholdMinusInference < 0)
            {
                BOOST_LOG_TRIVIAL(fatal) << "Elapsed inference time is greater than provided threshold time.\n";
                return EXIT_FAILURE;
            }
        }


    }
    catch (armnn::Exception const& e)
    {
        BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}
409
410// This will run a test
411int RunTest(const std::string& format,
412 const std::string& inputTensorShapesStr,
413 const vector<armnn::BackendId>& computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100414 const std::string& dynamicBackendsPath,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100415 const std::string& path,
416 const std::string& inputNames,
417 const std::string& inputTensorDataFilePaths,
418 const std::string& inputTypes,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100419 bool quantizeInput,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100420 const std::string& outputTypes,
421 const std::string& outputNames,
422 bool enableProfiling,
423 bool enableFp16TurboMode,
424 const double& thresholdTime,
Matthew Jackson54658b92019-08-27 15:35:59 +0100425 bool printIntermediate,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100426 const size_t subgraphId,
427 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
428{
429 std::string modelFormat = boost::trim_copy(format);
430 std::string modelPath = boost::trim_copy(path);
431 std::vector<std::string> inputNamesVector = ParseStringList(inputNames, ",");
432 std::vector<std::string> inputTensorShapesVector = ParseStringList(inputTensorShapesStr, ";");
433 std::vector<std::string> inputTensorDataFilePathsVector = ParseStringList(
434 inputTensorDataFilePaths, ",");
435 std::vector<std::string> outputNamesVector = ParseStringList(outputNames, ",");
436 std::vector<std::string> inputTypesVector = ParseStringList(inputTypes, ",");
437 std::vector<std::string> outputTypesVector = ParseStringList(outputTypes, ",");
438
439 // Parse model binary flag from the model-format string we got from the command-line
440 bool isModelBinary;
441 if (modelFormat.find("bin") != std::string::npos)
442 {
443 isModelBinary = true;
444 }
445 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
446 {
447 isModelBinary = false;
448 }
449 else
450 {
451 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
452 return EXIT_FAILURE;
453 }
454
455 if ((inputTensorShapesVector.size() != 0) && (inputTensorShapesVector.size() != inputNamesVector.size()))
456 {
457 BOOST_LOG_TRIVIAL(fatal) << "input-name and input-tensor-shape must have the same amount of elements.";
458 return EXIT_FAILURE;
459 }
460
461 if ((inputTensorDataFilePathsVector.size() != 0) &&
462 (inputTensorDataFilePathsVector.size() != inputNamesVector.size()))
463 {
464 BOOST_LOG_TRIVIAL(fatal) << "input-name and input-tensor-data must have the same amount of elements.";
465 return EXIT_FAILURE;
466 }
467
468 if (inputTypesVector.size() == 0)
469 {
470 //Defaults the value of all inputs to "float"
471 inputTypesVector.assign(inputNamesVector.size(), "float");
472 }
Matteo Martincigh08b51862019-08-29 16:26:10 +0100473 else if ((inputTypesVector.size() != 0) && (inputTypesVector.size() != inputNamesVector.size()))
474 {
475 BOOST_LOG_TRIVIAL(fatal) << "input-name and input-type must have the same amount of elements.";
476 return EXIT_FAILURE;
477 }
478
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100479 if (outputTypesVector.size() == 0)
480 {
481 //Defaults the value of all outputs to "float"
482 outputTypesVector.assign(outputNamesVector.size(), "float");
483 }
Matteo Martincigh08b51862019-08-29 16:26:10 +0100484 else if ((outputTypesVector.size() != 0) && (outputTypesVector.size() != outputNamesVector.size()))
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100485 {
Matteo Martincigh08b51862019-08-29 16:26:10 +0100486 BOOST_LOG_TRIVIAL(fatal) << "output-name and output-type must have the same amount of elements.";
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100487 return EXIT_FAILURE;
488 }
489
490 // Parse input tensor shape from the string we got from the command-line.
491 std::vector<std::unique_ptr<armnn::TensorShape>> inputTensorShapes;
492
493 if (!inputTensorShapesVector.empty())
494 {
495 inputTensorShapes.reserve(inputTensorShapesVector.size());
496
497 for(const std::string& shape : inputTensorShapesVector)
498 {
499 std::stringstream ss(shape);
500 std::vector<unsigned int> dims = ParseArray(ss);
501
502 try
503 {
504 // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
505 inputTensorShapes.push_back(std::make_unique<armnn::TensorShape>(dims.size(), dims.data()));
506 }
507 catch (const armnn::InvalidArgumentException& e)
508 {
509 BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
510 return EXIT_FAILURE;
511 }
512 }
513 }
514
515 // Check that threshold time is not less than zero
516 if (thresholdTime < 0)
517 {
518 BOOST_LOG_TRIVIAL(fatal) << "Threshold time supplied as a commoand line argument is less than zero.";
519 return EXIT_FAILURE;
520 }
521
522 // Forward to implementation based on the parser type
523 if (modelFormat.find("armnn") != std::string::npos)
524 {
525#if defined(ARMNN_SERIALIZER)
526 return MainImpl<armnnDeserializer::IDeserializer, float>(
527 modelPath.c_str(), isModelBinary, computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100528 dynamicBackendsPath, inputNamesVector, inputTensorShapes,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100529 inputTensorDataFilePathsVector, inputTypesVector, quantizeInput,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100530 outputTypesVector, outputNamesVector, enableProfiling,
Matthew Jackson54658b92019-08-27 15:35:59 +0100531 enableFp16TurboMode, thresholdTime, printIntermediate, subgraphId, runtime);
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100532#else
533 BOOST_LOG_TRIVIAL(fatal) << "Not built with serialization support.";
534 return EXIT_FAILURE;
535#endif
536 }
537 else if (modelFormat.find("caffe") != std::string::npos)
538 {
539#if defined(ARMNN_CAFFE_PARSER)
540 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100541 dynamicBackendsPath,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100542 inputNamesVector, inputTensorShapes,
543 inputTensorDataFilePathsVector, inputTypesVector,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100544 quantizeInput, outputTypesVector, outputNamesVector,
545 enableProfiling, enableFp16TurboMode, thresholdTime,
Matthew Jackson54658b92019-08-27 15:35:59 +0100546 printIntermediate, subgraphId, runtime);
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100547#else
548 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
549 return EXIT_FAILURE;
550#endif
551 }
552 else if (modelFormat.find("onnx") != std::string::npos)
553{
554#if defined(ARMNN_ONNX_PARSER)
555 return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100556 dynamicBackendsPath,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100557 inputNamesVector, inputTensorShapes,
558 inputTensorDataFilePathsVector, inputTypesVector,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100559 quantizeInput, outputTypesVector, outputNamesVector,
560 enableProfiling, enableFp16TurboMode, thresholdTime,
Matthew Jackson54658b92019-08-27 15:35:59 +0100561 printIntermediate, subgraphId, runtime);
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100562#else
563 BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
564 return EXIT_FAILURE;
565#endif
566 }
567 else if (modelFormat.find("tensorflow") != std::string::npos)
568 {
569#if defined(ARMNN_TF_PARSER)
570 return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100571 dynamicBackendsPath,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100572 inputNamesVector, inputTensorShapes,
573 inputTensorDataFilePathsVector, inputTypesVector,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100574 quantizeInput, outputTypesVector, outputNamesVector,
575 enableProfiling, enableFp16TurboMode, thresholdTime,
Matthew Jackson54658b92019-08-27 15:35:59 +0100576 printIntermediate, subgraphId, runtime);
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100577#else
578 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
579 return EXIT_FAILURE;
580#endif
581 }
582 else if(modelFormat.find("tflite") != std::string::npos)
583 {
584#if defined(ARMNN_TF_LITE_PARSER)
585 if (! isModelBinary)
586 {
587 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
588 for tflite files";
589 return EXIT_FAILURE;
590 }
591 return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100592 dynamicBackendsPath,
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100593 inputNamesVector, inputTensorShapes,
594 inputTensorDataFilePathsVector, inputTypesVector,
Narumol Prangnawarat610256f2019-06-26 15:10:46 +0100595 quantizeInput, outputTypesVector, outputNamesVector,
596 enableProfiling, enableFp16TurboMode, thresholdTime,
Matthew Jackson54658b92019-08-27 15:35:59 +0100597 printIntermediate, subgraphId, runtime);
Francis Murtaghbee4bc92019-06-18 12:30:37 +0100598#else
599 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
600 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
601 return EXIT_FAILURE;
602#endif
603 }
604 else
605 {
606 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
607 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
608 return EXIT_FAILURE;
609 }
610}
611
// Runs one test described by a CSV row: the row's cells are treated as argv-style
// command-line tokens, parsed with boost::program_options, validated, and forwarded
// to RunTest. 'runtime' is shared across rows so the IRuntime is created only once.
//
// Returns RunTest's result, or EXIT_FAILURE on an option-parsing/validation error
// or an invalid backend list.
int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IRuntime>& runtime,
               const bool enableProfiling, const bool enableFp16TurboMode, const double& thresholdTime,
               const bool printIntermediate)
{
    // Targets filled in by the option parser below.
    std::string modelFormat;
    std::string modelPath;
    std::string inputNames;
    std::string inputTensorShapes;
    std::string inputTensorDataFilePaths;
    std::string outputNames;
    std::string inputTypes;
    std::string outputTypes;
    std::string dynamicBackendsPath;

    size_t subgraphId = 0;

    const std::string backendsMessage = std::string("The preferred order of devices to run layers on by default. ")
                                      + std::string("Possible choices: ")
                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    po::options_description desc("Options");
    try
    {
        desc.add_options()
        ("model-format,f", po::value(&modelFormat),
         "armnn-binary, caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or "
         "tensorflow-text.")
        ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .armnn, .caffemodel, .prototxt, "
         ".tflite, .onnx")
        ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
         backendsMessage.c_str())
        ("dynamic-backends-path,b", po::value(&dynamicBackendsPath),
         "Path where to load any available dynamic backend from. "
         "If left empty (the default), dynamic backends will not be used.")
        ("input-name,i", po::value(&inputNames), "Identifier of the input tensors in the network separated by comma.")
        ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
         "executed. Defaults to 0.")
        ("input-tensor-shape,s", po::value(&inputTensorShapes),
         "The shape of the input tensors in the network as a flat array of integers separated by comma. "
         "Several shapes can be passed separating them by semicolon. "
         "This parameter is optional, depending on the network.")
        ("input-tensor-data,d", po::value(&inputTensorDataFilePaths),
         "Path to files containing the input data as a flat array separated by whitespace. "
         "Several paths can be passed separating them by comma.")
        ("input-type,y",po::value(&inputTypes), "The type of the input tensors in the network separated by comma. "
         "If unset, defaults to \"float\" for all defined inputs. "
         "Accepted values (float, int or qasymm8).")
        ("quantize-input,q",po::bool_switch()->default_value(false),
         "If this option is enabled, all float inputs will be quantized to qasymm8. "
         "If unset, default to not quantized. "
         "Accepted values (true or false)")
        ("output-type,z",po::value(&outputTypes), "The type of the output tensors in the network separated by comma. "
         "If unset, defaults to \"float\" for all defined outputs. "
         "Accepted values (float, int or qasymm8).")
        ("output-name,o", po::value(&outputNames),
         "Identifier of the output tensors in the network separated by comma.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    // Build an argv-style array of C strings from the CSV cells; the pointers
    // stay valid because csvRow outlives this function's use of them.
    std::vector<const char*> clOptions;
    clOptions.reserve(csvRow.values.size());
    for (const std::string& value : csvRow.values)
    {
        clOptions.push_back(value.c_str());
    }

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);

        po::notify(vm);

        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Get the value of the switch arguments.
    bool quantizeInput = vm["quantize-input"].as<bool>();

    // Get the preferred order of compute devices.
    std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();

    // Remove duplicates from the list of compute devices.
    RemoveDuplicateDevices(computeDevices);

    // Check that the specified compute devices are valid.
    std::string invalidBackends;
    if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
    {
        BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                 << invalidBackends;
        return EXIT_FAILURE;
    }

    return RunTest(modelFormat, inputTensorShapes, computeDevices, dynamicBackendsPath, modelPath, inputNames,
                   inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames,
                   enableProfiling, enableFp16TurboMode, thresholdTime, printIntermediate, subgraphId);
}
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100723}