blob: 29780104c29b97f2644b49cbc0f05677c2000ef1 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +01005#include <armnn/ArmNN.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01006#include <armnn/TypesUtils.hpp>
7
telsoa014fcda012018-03-09 14:13:49 +00008#if defined(ARMNN_CAFFE_PARSER)
9#include "armnnCaffeParser/ICaffeParser.hpp"
10#endif
surmeh01bceff2f2018-03-29 16:29:27 +010011#if defined(ARMNN_TF_PARSER)
12#include "armnnTfParser/ITfParser.hpp"
13#endif
telsoa01c577f2c2018-08-31 09:22:23 +010014#if defined(ARMNN_TF_LITE_PARSER)
15#include "armnnTfLiteParser/ITfLiteParser.hpp"
16#endif
17#if defined(ARMNN_ONNX_PARSER)
18#include "armnnOnnxParser/IOnnxParser.hpp"
19#endif
20#include "CsvReader.hpp"
telsoa014fcda012018-03-09 14:13:49 +000021#include "../InferenceTest.hpp"
22
telsoa01c577f2c2018-08-31 09:22:23 +010023#include <Logging.hpp>
24#include <Profiling.hpp>
25
26#include <boost/algorithm/string/trim.hpp>
telsoa014fcda012018-03-09 14:13:49 +000027#include <boost/algorithm/string/split.hpp>
28#include <boost/algorithm/string/classification.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010029#include <boost/program_options.hpp>
telsoa014fcda012018-03-09 14:13:49 +000030
31#include <iostream>
32#include <fstream>
telsoa01c577f2c2018-08-31 09:22:23 +010033#include <functional>
34#include <future>
35#include <algorithm>
36#include <iterator>
telsoa014fcda012018-03-09 14:13:49 +000037
38namespace
39{
40
telsoa01c577f2c2018-08-31 09:22:23 +010041// Configure boost::program_options for command-line parsing and validation.
42namespace po = boost::program_options;
43
telsoa014fcda012018-03-09 14:13:49 +000044template<typename T, typename TParseElementFunc>
45std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
46{
47 std::vector<T> result;
telsoa01c577f2c2018-08-31 09:22:23 +010048 // Processes line-by-line.
telsoa014fcda012018-03-09 14:13:49 +000049 std::string line;
50 while (std::getline(stream, line))
51 {
52 std::vector<std::string> tokens;
surmeh013537c2c2018-05-18 16:31:43 +010053 try
54 {
55 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
56 boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
57 }
58 catch (const std::exception& e)
59 {
60 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
61 continue;
62 }
telsoa014fcda012018-03-09 14:13:49 +000063 for (const std::string& token : tokens)
64 {
65 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
66 {
67 try
68 {
69 result.push_back(parseElementFunc(token));
70 }
71 catch (const std::exception&)
72 {
73 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
74 }
75 }
76 }
77 }
78
79 return result;
80}
81
telsoa01c577f2c2018-08-31 09:22:23 +010082bool CheckOption(const po::variables_map& vm,
83 const char* option)
84{
85 // Check that the given option is valid.
86 if (option == nullptr)
87 {
88 return false;
89 }
90
91 // Check whether 'option' is provided.
92 return vm.find(option) != vm.end();
93}
94
95void CheckOptionDependency(const po::variables_map& vm,
96 const char* option,
97 const char* required)
98{
99 // Check that the given options are valid.
100 if (option == nullptr || required == nullptr)
101 {
102 throw po::error("Invalid option to check dependency for");
103 }
104
105 // Check that if 'option' is provided, 'required' is also provided.
106 if (CheckOption(vm, option) && !vm[option].defaulted())
107 {
108 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
109 {
110 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
111 }
112 }
113}
114
// Enforces the inter-option dependencies for a single-test invocation:
// a model path is only usable together with its format, input/output names
// and input data, and an input shape only makes sense with a model path.
// Throws po::error on the first violated dependency (checks run in this
// fixed order, so the first failure determines the reported message).
void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "input-tensor-data");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}
123
124template<typename T>
125std::vector<T> ParseArray(std::istream& stream);
126
127template<>
128std::vector<float> ParseArray(std::istream& stream)
129{
130 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
131}
132
133template<>
134std::vector<unsigned int> ParseArray(std::istream& stream)
135{
136 return ParseArrayImpl<unsigned int>(stream,
137 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
138}
139
// Prints one output tensor to stdout as "<layer name>: v0 v1 v2 ...\n".
void PrintOutputData(const std::string& outputLayerName, const std::vector<float>& data)
{
    std::cout << outputLayerName << ": ";
    for (float value : data)
    {
        printf("%f ", value);
    }
    printf("\n");
}
149
David Beckf0b48452018-10-19 15:20:56 +0100150void RemoveDuplicateDevices(std::vector<armnn::BackendId>& computeDevices)
telsoa014fcda012018-03-09 14:13:49 +0000151{
telsoa01c577f2c2018-08-31 09:22:23 +0100152 // Mark the duplicate devices as 'Undefined'.
153 for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
154 {
155 for (auto j = std::next(i); j != computeDevices.end(); ++j)
156 {
157 if (*j == *i)
158 {
159 *j = armnn::Compute::Undefined;
160 }
161 }
162 }
163
164 // Remove 'Undefined' devices.
165 computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
166 computeDevices.end());
167}
168
telsoa01c577f2c2018-08-31 09:22:23 +0100169} // namespace
170
171template<typename TParser, typename TDataType>
172int MainImpl(const char* modelPath,
173 bool isModelBinary,
David Beckf0b48452018-10-19 15:20:56 +0100174 const std::vector<armnn::BackendId>& computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100175 const char* inputName,
176 const armnn::TensorShape* inputTensorShape,
177 const char* inputTensorDataFilePath,
178 const char* outputName,
179 bool enableProfiling,
180 const size_t subgraphId,
181 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
182{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000183 using TContainer = std::vector<TDataType>;
184
telsoa01c577f2c2018-08-31 09:22:23 +0100185 // Loads input tensor.
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000186 TContainer inputDataContainer;
telsoa014fcda012018-03-09 14:13:49 +0000187 {
188 std::ifstream inputTensorFile(inputTensorDataFilePath);
189 if (!inputTensorFile.good())
190 {
191 BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
telsoa01c577f2c2018-08-31 09:22:23 +0100192 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000193 }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000194 inputDataContainer = ParseArray<TDataType>(inputTensorFile);
telsoa014fcda012018-03-09 14:13:49 +0000195 }
196
197 try
198 {
telsoa01c577f2c2018-08-31 09:22:23 +0100199 // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
telsoa014fcda012018-03-09 14:13:49 +0000200 typename InferenceModel<TParser, TDataType>::Params params;
201 params.m_ModelPath = modelPath;
202 params.m_IsModelBinary = isModelBinary;
203 params.m_ComputeDevice = computeDevice;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000204 params.m_InputBindings = { inputName };
205 params.m_InputShapes = { *inputTensorShape };
206 params.m_OutputBindings = { outputName };
telsoa01c577f2c2018-08-31 09:22:23 +0100207 params.m_EnableProfiling = enableProfiling;
208 params.m_SubgraphId = subgraphId;
209 InferenceModel<TParser, TDataType> model(params, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000210
Aron Virginas-Tar9b937472019-01-30 17:41:47 +0000211 const size_t numOutputs = params.m_OutputBindings.size();
212 const size_t containerSize = model.GetOutputSize();
213
Aron Virginas-Tar93f5f972019-01-31 13:12:34 +0000214 // Set up input data container
215 std::vector<TContainer> inputData(1, std::move(inputDataContainer));
216
217 // Set up output data container
218 std::vector<TContainer> outputData(numOutputs, TContainer(containerSize));
Aron Virginas-Tar9b937472019-01-30 17:41:47 +0000219
220 // Execute model
Aron Virginas-Tar93f5f972019-01-31 13:12:34 +0000221 model.Run(inputData, outputData);
telsoa014fcda012018-03-09 14:13:49 +0000222
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000223 // Print output tensors
224 for (size_t i = 0; i < numOutputs; i++)
225 {
Aron Virginas-Tar93f5f972019-01-31 13:12:34 +0000226 PrintOutputData(params.m_OutputBindings[i], outputData[i]);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000227 }
telsoa014fcda012018-03-09 14:13:49 +0000228 }
229 catch (armnn::Exception const& e)
230 {
231 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100232 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000233 }
234
telsoa01c577f2c2018-08-31 09:22:23 +0100235 return EXIT_SUCCESS;
telsoa014fcda012018-03-09 14:13:49 +0000236}
237
telsoa01c577f2c2018-08-31 09:22:23 +0100238// This will run a test
239int RunTest(const std::string& modelFormat,
240 const std::string& inputTensorShapeStr,
David Beckf0b48452018-10-19 15:20:56 +0100241 const vector<armnn::BackendId>& computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100242 const std::string& modelPath,
243 const std::string& inputName,
244 const std::string& inputTensorDataFilePath,
245 const std::string& outputName,
246 bool enableProfiling,
247 const size_t subgraphId,
248 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
telsoa014fcda012018-03-09 14:13:49 +0000249{
telsoa014fcda012018-03-09 14:13:49 +0000250 // Parse model binary flag from the model-format string we got from the command-line
251 bool isModelBinary;
252 if (modelFormat.find("bin") != std::string::npos)
253 {
254 isModelBinary = true;
255 }
256 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
257 {
258 isModelBinary = false;
259 }
260 else
261 {
262 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
telsoa01c577f2c2018-08-31 09:22:23 +0100263 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000264 }
265
266 // Parse input tensor shape from the string we got from the command-line.
267 std::unique_ptr<armnn::TensorShape> inputTensorShape;
268 if (!inputTensorShapeStr.empty())
269 {
270 std::stringstream ss(inputTensorShapeStr);
271 std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);
surmeh013537c2c2018-05-18 16:31:43 +0100272
273 try
274 {
275 // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
276 inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
277 }
278 catch (const armnn::InvalidArgumentException& e)
279 {
280 BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100281 return EXIT_FAILURE;
surmeh013537c2c2018-05-18 16:31:43 +0100282 }
telsoa014fcda012018-03-09 14:13:49 +0000283 }
284
285 // Forward to implementation based on the parser type
286 if (modelFormat.find("caffe") != std::string::npos)
287 {
288#if defined(ARMNN_CAFFE_PARSER)
289 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100290 inputName.c_str(), inputTensorShape.get(),
291 inputTensorDataFilePath.c_str(), outputName.c_str(),
292 enableProfiling, subgraphId, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000293#else
294 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100295 return EXIT_FAILURE;
296#endif
297 }
298 else if (modelFormat.find("onnx") != std::string::npos)
299{
300#if defined(ARMNN_ONNX_PARSER)
301 return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
302 inputName.c_str(), inputTensorShape.get(),
303 inputTensorDataFilePath.c_str(), outputName.c_str(),
304 enableProfiling, subgraphId, runtime);
305#else
306 BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
307 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000308#endif
309 }
310 else if (modelFormat.find("tensorflow") != std::string::npos)
311 {
surmeh01bceff2f2018-03-29 16:29:27 +0100312#if defined(ARMNN_TF_PARSER)
313 return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100314 inputName.c_str(), inputTensorShape.get(),
315 inputTensorDataFilePath.c_str(), outputName.c_str(),
316 enableProfiling, subgraphId, runtime);
surmeh01bceff2f2018-03-29 16:29:27 +0100317#else
telsoa014fcda012018-03-09 14:13:49 +0000318 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100319 return EXIT_FAILURE;
320#endif
321 }
322 else if(modelFormat.find("tflite") != std::string::npos)
323 {
324#if defined(ARMNN_TF_LITE_PARSER)
325 if (! isModelBinary)
326 {
327 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
328 for tflite files";
329 return EXIT_FAILURE;
330 }
331 return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
332 inputName.c_str(), inputTensorShape.get(),
333 inputTensorDataFilePath.c_str(), outputName.c_str(),
334 enableProfiling, subgraphId, runtime);
335#else
336 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
337 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
338 return EXIT_FAILURE;
surmeh01bceff2f2018-03-29 16:29:27 +0100339#endif
telsoa014fcda012018-03-09 14:13:49 +0000340 }
341 else
342 {
343 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
telsoa01c577f2c2018-08-31 09:22:23 +0100344 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
345 return EXIT_FAILURE;
346 }
347}
348
// Runs one test case described by a CSV row. The row's values are treated as
// an argv-style command line and re-parsed with (a subset of) the same
// boost::program_options grammar as main(), then forwarded to RunTest using
// the shared 'runtime'. Invoked concurrently from main() via std::async when
// --concurrent is set. Returns EXIT_SUCCESS or EXIT_FAILURE.
int RunCsvTest(const armnnUtils::CsvRow &csvRow,
               const std::shared_ptr<armnn::IRuntime>& runtime, const bool enableProfiling)
{
    // Destinations for the parsed per-test options.
    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = std::string("The preferred order of devices to run layers on by default. ")
                                      + std::string("Possible choices: ")
                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    // Per-test option grammar (no --help/--test-cases/--concurrent here;
    // those are handled once by main()).
    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("model-format,f", po::value(&modelFormat),
             "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
             " .onnx")
            ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
             "executed. Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    // Build an argv-style view over the CSV values (the caller prepends the
    // executable name so parse_command_line sees a conventional argv).
    std::vector<const char*> clOptions;
    clOptions.reserve(csvRow.values.size());
    for (const std::string& value : csvRow.values)
    {
        clOptions.push_back(value.c_str());
    }

    // Parse, notify and validate inter-option dependencies in one guarded step.
    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);

        po::notify(vm);

        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Remove leading and trailing whitespaces from the parsed arguments.
    boost::trim(modelFormat);
    boost::trim(modelPath);
    boost::trim(inputName);
    boost::trim(inputTensorShapeStr);
    boost::trim(inputTensorDataFilePath);
    boost::trim(outputName);

    // Get the preferred order of compute devices.
    // NOTE(review): if the CSV row omits --compute, vm["compute"].as<...>
    // looks like it would throw boost::bad_any_cast, which is not caught
    // here — confirm whether CSV rows are guaranteed to contain --compute.
    std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();

    // Remove duplicates from the list of compute devices.
    RemoveDuplicateDevices(computeDevices);

    // Check that the specified compute devices are valid.
    std::string invalidBackends;
    if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
    {
        BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                 << invalidBackends;
        return EXIT_FAILURE;
    }

    return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                   modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime);
}
444
// Entry point. Either runs every test case listed in a CSV file (optionally
// in parallel, sharing one IRuntime) or runs the single test described by the
// command-line options. Returns EXIT_SUCCESS only if every executed test
// succeeds.
int main(int argc, const char* argv[])
{
    // Configures logging for both the ARMNN library and this test program.
#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif
    armnn::ConfigureLogging(true, true, level);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);

    // Destinations for the parsed command-line options.
    std::string testCasesFile;

    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    // Full command-line grammar. Note --model-format and --model-path are
    // required(); po::notify below throws if they are missing.
    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("help", "Display usage information")
            ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. "
             "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, "
             "as they are expected to be defined in the file for each test in particular.")
            ("concurrent,n", po::bool_switch()->default_value(false),
             "Whether or not the test cases should be executed in parallel")
            ("model-format,f", po::value(&modelFormat)->required(),
             "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
             " .tflite, .onnx")
            ("compute,c", po::value<std::vector<std::string>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,x", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be executed."
              "Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.")
            ("event-based-profiling,e", po::bool_switch()->default_value(false),
             "Enables built in profiler. If unset, defaults to off.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    // Parses the command-line.
    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        // Show usage for --help or when no arguments were given; this happens
        // before po::notify so required options are not enforced here.
        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
            std::cout << "Prints the resulting output tensor." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            return EXIT_SUCCESS;
        }

        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Get the value of the switch arguments.
    bool concurrent = vm["concurrent"].as<bool>();
    bool enableProfiling = vm["event-based-profiling"].as<bool>();

    // Check whether we have to load test cases from a file.
    if (CheckOption(vm, "test-cases"))
    {
        // Check that the file exists.
        if (!boost::filesystem::exists(testCasesFile))
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist";
            return EXIT_FAILURE;
        }

        // Parse CSV file and extract test cases
        armnnUtils::CsvReader reader;
        std::vector<armnnUtils::CsvRow> testCases = reader.ParseFile(testCasesFile);

        // Check that there is at least one test case to run
        if (testCases.empty())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases";
            return EXIT_FAILURE;
        }

        // Create a single runtime shared by all test cases.
        armnn::IRuntime::CreationOptions options;
        options.m_EnableGpuProfiling = enableProfiling;

        std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(options));

        // Prepended to each row so RunCsvTest sees a conventional argv[0].
        const std::string executableName("ExecuteNetwork");

        // Check whether we need to run the test cases concurrently
        if (concurrent)
        {
            std::vector<std::future<int>> results;
            results.reserve(testCases.size());

            // Run each test case in its own thread
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime),
                                             enableProfiling));
            }

            // Check results; the first failing future short-circuits the rest
            // (remaining futures are still joined by their destructors).
            for (auto& result : results)
            {
                if (result.get() != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }
        else
        {
            // Run tests sequentially
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                if (RunCsvTest(testCase, runtime, enableProfiling) != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }

        return EXIT_SUCCESS;
    }
    else // Run single test
    {
        // Get the preferred order of compute devices. If none are specified, default to using CpuRef
        const std::string computeOption("compute");
        std::vector<std::string> computeDevicesAsStrings = CheckOption(vm, computeOption.c_str()) ?
            vm[computeOption].as<std::vector<std::string>>() :
            std::vector<std::string>({ "CpuRef" });
        std::vector<armnn::BackendId> computeDevices(computeDevicesAsStrings.begin(), computeDevicesAsStrings.end());

        // Remove duplicates from the list of compute devices.
        RemoveDuplicateDevices(computeDevices);

        // Check that the specified compute devices are valid.
        std::string invalidBackends;
        if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
        {
            BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                     << invalidBackends;
            return EXIT_FAILURE;
        }

        // Enforce inter-option dependencies (e.g. model-path requires
        // model-format) before dispatching the run.
        try
        {
            CheckOptionDependencies(vm);
        }
        catch (const po::error& e)
        {
            std::cerr << e.what() << std::endl << std::endl;
            std::cerr << desc << std::endl;
            return EXIT_FAILURE;
        }

        // No runtime argument: RunTest/MainImpl create one internally when
        // given nullptr (per MainImpl's defaulted parameter).
        return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                       modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId);
    }
}