//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#if defined(ARMNN_TF_PARSER)
#include "armnnTfParser/ITfParser.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif
#include "CsvReader.hpp"
#include "../InferenceTest.hpp"

#include <Logging.hpp>
#include <Profiling.hpp>

#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <boost/program_options.hpp>

#include <iostream>
#include <fstream>
#include <functional>
#include <future>
#include <algorithm>
#include <iterator>

namespace
{

// Configure boost::program_options for command-line parsing and validation.
namespace po = boost::program_options;

template<typename T, typename TParseElementFunc>
std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
{
    std::vector<T> result;
    // Processes line-by-line.
    std::string line;
    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens;
        try
        {
            // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
            boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
        }
        catch (const std::exception& e)
        {
            BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
            continue;
        }
        for (const std::string& token : tokens)
        {
            if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
            {
                try
                {
                    result.push_back(parseElementFunc(token));
                }
                catch (const std::exception&)
                {
                    BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return result;
}
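
// Example (hypothetical data): a stream holding "1.0 2.5\n3.0,abc;4.0" parsed with std::stof
// yields {1.0f, 2.5f, 3.0f, 4.0f}; the invalid token "abc" is logged and skipped.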

bool CheckOption(const po::variables_map& vm,
                 const char* option)
{
    // Check that the given option is valid.
    if (option == nullptr)
    {
        return false;
    }

    // Check whether 'option' is provided.
    return vm.find(option) != vm.end();
}

void CheckOptionDependency(const po::variables_map& vm,
                           const char* option,
                           const char* required)
{
    // Check that the given options are valid.
    if (option == nullptr || required == nullptr)
    {
        throw po::error("Invalid option to check dependency for");
    }

    // Check that if 'option' is provided, 'required' is also provided.
    if (CheckOption(vm, option) && !vm[option].defaulted())
    {
        if (!CheckOption(vm, required) || vm[required].defaulted())
        {
            throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
        }
    }
}

void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "input-tensor-data");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}
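
// For example, a command line that supplies --model-path without --model-format fails here
// with a po::error reading "Option 'model-path' requires option 'model-format'."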

template<typename T>
std::vector<T> ParseArray(std::istream& stream);

template<>
std::vector<float> ParseArray(std::istream& stream)
{
    return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
}

template<>
std::vector<unsigned int> ParseArray(std::istream& stream)
{
    return ParseArrayImpl<unsigned int>(stream,
        [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
}
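
// Example (hypothetical shape string): parsing "1 3 224 224" with ParseArray<unsigned int>
// yields {1, 3, 224, 224}, which RunTest below turns into an armnn::TensorShape.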
139
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000140void PrintOutputData(const std::string& outputLayerName, const std::vector<float>& data)
telsoa014fcda012018-03-09 14:13:49 +0000141{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000142 std::cout << outputLayerName << ": ";
143 std::copy(data.begin(), data.end(),
144 std::ostream_iterator<float>(std::cout, " "));
145 std::cout << std::endl;
telsoa014fcda012018-03-09 14:13:49 +0000146}
147
David Beckf0b48452018-10-19 15:20:56 +0100148void RemoveDuplicateDevices(std::vector<armnn::BackendId>& computeDevices)
telsoa014fcda012018-03-09 14:13:49 +0000149{
telsoa01c577f2c2018-08-31 09:22:23 +0100150 // Mark the duplicate devices as 'Undefined'.
151 for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
152 {
153 for (auto j = std::next(i); j != computeDevices.end(); ++j)
154 {
155 if (*j == *i)
156 {
157 *j = armnn::Compute::Undefined;
158 }
159 }
160 }
161
162 // Remove 'Undefined' devices.
163 computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
164 computeDevices.end());
165}
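
// Example: {CpuAcc, GpuAcc, CpuAcc, CpuRef} becomes {CpuAcc, GpuAcc, CpuRef}; the first
// occurrence of each backend keeps its place, so the preferred order is preserved.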

} // namespace

template<typename TParser, typename TDataType>
int MainImpl(const char* modelPath,
             bool isModelBinary,
             const std::vector<armnn::BackendId>& computeDevice,
             const char* inputName,
             const armnn::TensorShape* inputTensorShape,
             const char* inputTensorDataFilePath,
             const char* outputName,
             bool enableProfiling,
             const size_t subgraphId,
             const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    using TContainer = std::vector<TDataType>;

    // Loads input tensor.
    TContainer inputDataContainer;
    {
        std::ifstream inputTensorFile(inputTensorDataFilePath);
        if (!inputTensorFile.good())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
            return EXIT_FAILURE;
        }
        inputDataContainer = ParseArray<TDataType>(inputTensorFile);
    }

    try
    {
        // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
        typename InferenceModel<TParser, TDataType>::Params params;
        params.m_ModelPath = modelPath;
        params.m_IsModelBinary = isModelBinary;
        params.m_ComputeDevice = computeDevice;
        params.m_InputBindings = { inputName };
        // The input tensor shape is optional on the command line, so only dereference it when provided.
        if (inputTensorShape != nullptr)
        {
            params.m_InputShapes = { *inputTensorShape };
        }
        params.m_OutputBindings = { outputName };
        params.m_EnableProfiling = enableProfiling;
        params.m_SubgraphId = subgraphId;
        InferenceModel<TParser, TDataType> model(params, runtime);

        // Executes the model.
        const size_t numOutputs = params.m_OutputBindings.size();
        std::vector<TContainer> outputDataContainers(numOutputs);
        model.Run({ inputDataContainer }, outputDataContainers);

        // Prints the output tensors.
        for (size_t i = 0; i < numOutputs; i++)
        {
            PrintOutputData(params.m_OutputBindings[i], outputDataContainers[i]);
        }
    }
    catch (armnn::Exception const& e)
    {
        BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

// Runs a single test: parses the model format and input tensor shape, then dispatches to the
// MainImpl instantiation that matches the requested parser.
int RunTest(const std::string& modelFormat,
            const std::string& inputTensorShapeStr,
            const std::vector<armnn::BackendId>& computeDevice,
            const std::string& modelPath,
            const std::string& inputName,
            const std::string& inputTensorDataFilePath,
            const std::string& outputName,
            bool enableProfiling,
            const size_t subgraphId,
            const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    // Parse model binary flag from the model-format string we got from the command-line.
    bool isModelBinary;
    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    // Parse input tensor shape from the string we got from the command-line.
    std::unique_ptr<armnn::TensorShape> inputTensorShape;
    if (!inputTensorShapeStr.empty())
    {
        std::stringstream ss(inputTensorShapeStr);
        std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);

        try
        {
            // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
            inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
        }
        catch (const armnn::InvalidArgumentException& e)
        {
            BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
            return EXIT_FAILURE;
        }
    }

    // Forward to implementation based on the parser type.
    if (modelFormat.find("caffe") != std::string::npos)
    {
#if defined(ARMNN_CAFFE_PARSER)
        return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                               inputName.c_str(), inputTensorShape.get(),
                                                               inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                               enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("onnx") != std::string::npos)
    {
#if defined(ARMNN_ONNX_PARSER)
        return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                             inputName.c_str(), inputTensorShape.get(),
                                                             inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                             enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tensorflow") != std::string::npos)
    {
#if defined(ARMNN_TF_PARSER)
        return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                         inputName.c_str(), inputTensorShape.get(),
                                                         inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                         enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tflite") != std::string::npos)
    {
#if defined(ARMNN_TF_LITE_PARSER)
        if (!isModelBinary)
        {
            BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat
                                     << "'. Only 'binary' format supported for tflite files";
            return EXIT_FAILURE;
        }
        return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                                 inputName.c_str(), inputTensorShape.get(),
                                                                 inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                                 enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Tflite parser support.";
        return EXIT_FAILURE;
#endif
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
                                    "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
        return EXIT_FAILURE;
    }
}

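// A hypothetical test-cases CSV row (model path and tensor names are placeholders); each cell
// of a row is passed to RunCsvTest below as one command-line argument:
//
//   -f,tflite-binary,-m,model.tflite,-c,CpuAcc,-i,input,-o,output,-d,input_data.txt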
int RunCsvTest(const armnnUtils::CsvRow& csvRow,
               const std::shared_ptr<armnn::IRuntime>& runtime, const bool enableProfiling)
{
    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = std::string("The preferred order of devices to run layers on by default. ")
                                        + std::string("Possible choices: ")
                                        + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("model-format,f", po::value(&modelFormat),
             "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
             " .onnx")
            ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
             "executed. Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    std::vector<const char*> clOptions;
    clOptions.reserve(csvRow.values.size());
    for (const std::string& value : csvRow.values)
    {
        clOptions.push_back(value.c_str());
    }

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);

        po::notify(vm);

        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Remove leading and trailing whitespace from the parsed arguments.
    boost::trim(modelFormat);
    boost::trim(modelPath);
    boost::trim(inputName);
    boost::trim(inputTensorShapeStr);
    boost::trim(inputTensorDataFilePath);
    boost::trim(outputName);

    // Get the preferred order of compute devices.
    std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();

    // Remove duplicates from the list of compute devices.
    RemoveDuplicateDevices(computeDevices);

    // Check that the specified compute devices are valid.
    std::string invalidBackends;
    if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
    {
        BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                 << invalidBackends;
        return EXIT_FAILURE;
    }

    return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                   modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime);
}

int main(int argc, const char* argv[])
{
    // Configures logging for both the ARMNN library and this test program.
#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif
    armnn::ConfigureLogging(true, true, level);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);

    std::string testCasesFile;

    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
                                        + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("help", "Display usage information")
            ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. "
             "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, "
             "as they are expected to be defined in the file for each test in particular.")
            ("concurrent,n", po::bool_switch()->default_value(false),
             "Whether or not the test cases should be executed in parallel")
            ("model-format,f", po::value(&modelFormat)->required(),
             "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
             " .tflite, .onnx")
            ("compute,c", po::value<std::vector<std::string>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,x", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
             "executed. Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.")
            ("event-based-profiling,e", po::bool_switch()->default_value(false),
             "Enables built-in profiler. If unset, defaults to off.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }
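
// A hypothetical single-test invocation (paths and tensor names are placeholders):
//
//   ExecuteNetwork -f tflite-binary -m model.tflite -c CpuAcc CpuRef \
//                  -i input -o output -d input_data.txt -s "1 224 224 3"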

    // Parses the command-line.
    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
            std::cout << "Prints the resulting output tensor." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            return EXIT_SUCCESS;
        }

        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Get the value of the switch arguments.
    bool concurrent = vm["concurrent"].as<bool>();
    bool enableProfiling = vm["event-based-profiling"].as<bool>();

    // Check whether we have to load test cases from a file.
    if (CheckOption(vm, "test-cases"))
    {
        // Check that the file exists.
        if (!boost::filesystem::exists(testCasesFile))
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist";
            return EXIT_FAILURE;
        }

        // Parse CSV file and extract test cases.
        armnnUtils::CsvReader reader;
        std::vector<armnnUtils::CsvRow> testCases = reader.ParseFile(testCasesFile);

        // Check that there is at least one test case to run.
        if (testCases.empty())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases";
            return EXIT_FAILURE;
        }

        // Create the runtime.
        armnn::IRuntime::CreationOptions options;
        options.m_EnableGpuProfiling = enableProfiling;

        std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(options));

        const std::string executableName("ExecuteNetwork");

        // Check whether we need to run the test cases concurrently.
        if (concurrent)
        {
            std::vector<std::future<int>> results;
            results.reserve(testCases.size());

            // Run each test case in its own thread.
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime),
                                             enableProfiling));
            }

            // Check results.
            for (auto& result : results)
            {
                if (result.get() != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }
        else
        {
            // Run tests sequentially.
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                if (RunCsvTest(testCase, runtime, enableProfiling) != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }

        return EXIT_SUCCESS;
    }
    else // Run single test.
    {
        // Get the preferred order of compute devices. If none are specified, default to using CpuRef.
        const std::string computeOption("compute");
        std::vector<std::string> computeDevicesAsStrings = CheckOption(vm, computeOption.c_str()) ?
            vm[computeOption].as<std::vector<std::string>>() :
            std::vector<std::string>({ "CpuRef" });
        std::vector<armnn::BackendId> computeDevices(computeDevicesAsStrings.begin(), computeDevicesAsStrings.end());

        // Remove duplicates from the list of compute devices.
        RemoveDuplicateDevices(computeDevices);

        // Check that the specified compute devices are valid.
        std::string invalidBackends;
        if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
        {
            BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                     << invalidBackends;
            return EXIT_FAILURE;
        }

        try
        {
            CheckOptionDependencies(vm);
        }
        catch (const po::error& e)
        {
            std::cerr << e.what() << std::endl << std::endl;
            std::cerr << desc << std::endl;
            return EXIT_FAILURE;
        }

        return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                       modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId);
    }
}