blob: cfddc38a99e3e1430c398353050302a25b2e1aea [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +01005#include <armnn/ArmNN.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01006#include <armnn/TypesUtils.hpp>
7
telsoa014fcda012018-03-09 14:13:49 +00008#if defined(ARMNN_CAFFE_PARSER)
9#include "armnnCaffeParser/ICaffeParser.hpp"
10#endif
surmeh01bceff2f2018-03-29 16:29:27 +010011#if defined(ARMNN_TF_PARSER)
12#include "armnnTfParser/ITfParser.hpp"
13#endif
telsoa01c577f2c2018-08-31 09:22:23 +010014#if defined(ARMNN_TF_LITE_PARSER)
15#include "armnnTfLiteParser/ITfLiteParser.hpp"
16#endif
17#if defined(ARMNN_ONNX_PARSER)
18#include "armnnOnnxParser/IOnnxParser.hpp"
19#endif
20#include "CsvReader.hpp"
telsoa014fcda012018-03-09 14:13:49 +000021#include "../InferenceTest.hpp"
22
telsoa01c577f2c2018-08-31 09:22:23 +010023#include <Logging.hpp>
24#include <Profiling.hpp>
25
26#include <boost/algorithm/string/trim.hpp>
telsoa014fcda012018-03-09 14:13:49 +000027#include <boost/algorithm/string/split.hpp>
28#include <boost/algorithm/string/classification.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010029#include <boost/program_options.hpp>
telsoa014fcda012018-03-09 14:13:49 +000030
31#include <iostream>
32#include <fstream>
telsoa01c577f2c2018-08-31 09:22:23 +010033#include <functional>
34#include <future>
35#include <algorithm>
36#include <iterator>
telsoa014fcda012018-03-09 14:13:49 +000037
38namespace
39{
40
telsoa01c577f2c2018-08-31 09:22:23 +010041// Configure boost::program_options for command-line parsing and validation.
42namespace po = boost::program_options;
43
telsoa014fcda012018-03-09 14:13:49 +000044template<typename T, typename TParseElementFunc>
45std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
46{
47 std::vector<T> result;
telsoa01c577f2c2018-08-31 09:22:23 +010048 // Processes line-by-line.
telsoa014fcda012018-03-09 14:13:49 +000049 std::string line;
50 while (std::getline(stream, line))
51 {
52 std::vector<std::string> tokens;
surmeh013537c2c2018-05-18 16:31:43 +010053 try
54 {
55 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
56 boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
57 }
58 catch (const std::exception& e)
59 {
60 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
61 continue;
62 }
telsoa014fcda012018-03-09 14:13:49 +000063 for (const std::string& token : tokens)
64 {
65 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
66 {
67 try
68 {
69 result.push_back(parseElementFunc(token));
70 }
71 catch (const std::exception&)
72 {
73 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
74 }
75 }
76 }
77 }
78
79 return result;
80}
81
telsoa01c577f2c2018-08-31 09:22:23 +010082bool CheckOption(const po::variables_map& vm,
83 const char* option)
84{
85 // Check that the given option is valid.
86 if (option == nullptr)
87 {
88 return false;
89 }
90
91 // Check whether 'option' is provided.
92 return vm.find(option) != vm.end();
93}
94
95void CheckOptionDependency(const po::variables_map& vm,
96 const char* option,
97 const char* required)
98{
99 // Check that the given options are valid.
100 if (option == nullptr || required == nullptr)
101 {
102 throw po::error("Invalid option to check dependency for");
103 }
104
105 // Check that if 'option' is provided, 'required' is also provided.
106 if (CheckOption(vm, option) && !vm[option].defaulted())
107 {
108 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
109 {
110 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
111 }
112 }
113}
114
115void CheckOptionDependencies(const po::variables_map& vm)
116{
117 CheckOptionDependency(vm, "model-path", "model-format");
118 CheckOptionDependency(vm, "model-path", "input-name");
119 CheckOptionDependency(vm, "model-path", "input-tensor-data");
120 CheckOptionDependency(vm, "model-path", "output-name");
121 CheckOptionDependency(vm, "input-tensor-shape", "model-path");
telsoa014fcda012018-03-09 14:13:49 +0000122}
123
124template<typename T>
125std::vector<T> ParseArray(std::istream& stream);
126
127template<>
128std::vector<float> ParseArray(std::istream& stream)
129{
130 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
131}
132
133template<>
134std::vector<unsigned int> ParseArray(std::istream& stream)
135{
136 return ParseArrayImpl<unsigned int>(stream,
137 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
138}
139
140void PrintArray(const std::vector<float>& v)
141{
142 for (size_t i = 0; i < v.size(); i++)
143 {
144 printf("%f ", v[i]);
145 }
146 printf("\n");
147}
148
David Beckf0b48452018-10-19 15:20:56 +0100149void RemoveDuplicateDevices(std::vector<armnn::BackendId>& computeDevices)
telsoa014fcda012018-03-09 14:13:49 +0000150{
telsoa01c577f2c2018-08-31 09:22:23 +0100151 // Mark the duplicate devices as 'Undefined'.
152 for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
153 {
154 for (auto j = std::next(i); j != computeDevices.end(); ++j)
155 {
156 if (*j == *i)
157 {
158 *j = armnn::Compute::Undefined;
159 }
160 }
161 }
162
163 // Remove 'Undefined' devices.
164 computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
165 computeDevices.end());
166}
167
telsoa01c577f2c2018-08-31 09:22:23 +0100168} // namespace
169
170template<typename TParser, typename TDataType>
171int MainImpl(const char* modelPath,
172 bool isModelBinary,
David Beckf0b48452018-10-19 15:20:56 +0100173 const std::vector<armnn::BackendId>& computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100174 const char* inputName,
175 const armnn::TensorShape* inputTensorShape,
176 const char* inputTensorDataFilePath,
177 const char* outputName,
178 bool enableProfiling,
179 const size_t subgraphId,
180 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
181{
182 // Loads input tensor.
telsoa014fcda012018-03-09 14:13:49 +0000183 std::vector<TDataType> input;
184 {
185 std::ifstream inputTensorFile(inputTensorDataFilePath);
186 if (!inputTensorFile.good())
187 {
188 BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
telsoa01c577f2c2018-08-31 09:22:23 +0100189 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000190 }
191 input = ParseArray<TDataType>(inputTensorFile);
192 }
193
194 try
195 {
telsoa01c577f2c2018-08-31 09:22:23 +0100196 // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
telsoa014fcda012018-03-09 14:13:49 +0000197 typename InferenceModel<TParser, TDataType>::Params params;
198 params.m_ModelPath = modelPath;
199 params.m_IsModelBinary = isModelBinary;
200 params.m_ComputeDevice = computeDevice;
201 params.m_InputBinding = inputName;
202 params.m_InputTensorShape = inputTensorShape;
203 params.m_OutputBinding = outputName;
telsoa01c577f2c2018-08-31 09:22:23 +0100204 params.m_EnableProfiling = enableProfiling;
205 params.m_SubgraphId = subgraphId;
206 InferenceModel<TParser, TDataType> model(params, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000207
telsoa01c577f2c2018-08-31 09:22:23 +0100208 // Executes the model.
telsoa014fcda012018-03-09 14:13:49 +0000209 std::vector<TDataType> output(model.GetOutputSize());
210 model.Run(input, output);
211
telsoa01c577f2c2018-08-31 09:22:23 +0100212 // Prints the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000213 PrintArray(output);
214 }
215 catch (armnn::Exception const& e)
216 {
217 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100218 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000219 }
220
telsoa01c577f2c2018-08-31 09:22:23 +0100221 return EXIT_SUCCESS;
telsoa014fcda012018-03-09 14:13:49 +0000222}
223
telsoa01c577f2c2018-08-31 09:22:23 +0100224// This will run a test
225int RunTest(const std::string& modelFormat,
226 const std::string& inputTensorShapeStr,
David Beckf0b48452018-10-19 15:20:56 +0100227 const vector<armnn::BackendId>& computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100228 const std::string& modelPath,
229 const std::string& inputName,
230 const std::string& inputTensorDataFilePath,
231 const std::string& outputName,
232 bool enableProfiling,
233 const size_t subgraphId,
234 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
telsoa014fcda012018-03-09 14:13:49 +0000235{
telsoa014fcda012018-03-09 14:13:49 +0000236 // Parse model binary flag from the model-format string we got from the command-line
237 bool isModelBinary;
238 if (modelFormat.find("bin") != std::string::npos)
239 {
240 isModelBinary = true;
241 }
242 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
243 {
244 isModelBinary = false;
245 }
246 else
247 {
248 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
telsoa01c577f2c2018-08-31 09:22:23 +0100249 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000250 }
251
252 // Parse input tensor shape from the string we got from the command-line.
253 std::unique_ptr<armnn::TensorShape> inputTensorShape;
254 if (!inputTensorShapeStr.empty())
255 {
256 std::stringstream ss(inputTensorShapeStr);
257 std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);
surmeh013537c2c2018-05-18 16:31:43 +0100258
259 try
260 {
261 // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
262 inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
263 }
264 catch (const armnn::InvalidArgumentException& e)
265 {
266 BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100267 return EXIT_FAILURE;
surmeh013537c2c2018-05-18 16:31:43 +0100268 }
telsoa014fcda012018-03-09 14:13:49 +0000269 }
270
271 // Forward to implementation based on the parser type
272 if (modelFormat.find("caffe") != std::string::npos)
273 {
274#if defined(ARMNN_CAFFE_PARSER)
275 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100276 inputName.c_str(), inputTensorShape.get(),
277 inputTensorDataFilePath.c_str(), outputName.c_str(),
278 enableProfiling, subgraphId, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000279#else
280 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100281 return EXIT_FAILURE;
282#endif
283 }
284 else if (modelFormat.find("onnx") != std::string::npos)
285{
286#if defined(ARMNN_ONNX_PARSER)
287 return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
288 inputName.c_str(), inputTensorShape.get(),
289 inputTensorDataFilePath.c_str(), outputName.c_str(),
290 enableProfiling, subgraphId, runtime);
291#else
292 BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
293 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000294#endif
295 }
296 else if (modelFormat.find("tensorflow") != std::string::npos)
297 {
surmeh01bceff2f2018-03-29 16:29:27 +0100298#if defined(ARMNN_TF_PARSER)
299 return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100300 inputName.c_str(), inputTensorShape.get(),
301 inputTensorDataFilePath.c_str(), outputName.c_str(),
302 enableProfiling, subgraphId, runtime);
surmeh01bceff2f2018-03-29 16:29:27 +0100303#else
telsoa014fcda012018-03-09 14:13:49 +0000304 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100305 return EXIT_FAILURE;
306#endif
307 }
308 else if(modelFormat.find("tflite") != std::string::npos)
309 {
310#if defined(ARMNN_TF_LITE_PARSER)
311 if (! isModelBinary)
312 {
313 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
314 for tflite files";
315 return EXIT_FAILURE;
316 }
317 return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
318 inputName.c_str(), inputTensorShape.get(),
319 inputTensorDataFilePath.c_str(), outputName.c_str(),
320 enableProfiling, subgraphId, runtime);
321#else
322 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
323 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
324 return EXIT_FAILURE;
surmeh01bceff2f2018-03-29 16:29:27 +0100325#endif
telsoa014fcda012018-03-09 14:13:49 +0000326 }
327 else
328 {
329 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
telsoa01c577f2c2018-08-31 09:22:23 +0100330 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
331 return EXIT_FAILURE;
332 }
333}
334
335int RunCsvTest(const armnnUtils::CsvRow &csvRow,
Nina Drozd549ae372018-09-10 14:26:44 +0100336 const std::shared_ptr<armnn::IRuntime>& runtime, const bool enableProfiling)
telsoa01c577f2c2018-08-31 09:22:23 +0100337{
338 std::string modelFormat;
339 std::string modelPath;
340 std::string inputName;
341 std::string inputTensorShapeStr;
342 std::string inputTensorDataFilePath;
343 std::string outputName;
344
345 size_t subgraphId = 0;
346
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100347 const std::string backendsMessage = std::string("The preferred order of devices to run layers on by default. ")
348 + std::string("Possible choices: ")
349 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
350
telsoa01c577f2c2018-08-31 09:22:23 +0100351 po::options_description desc("Options");
352 try
353 {
354 desc.add_options()
355 ("model-format,f", po::value(&modelFormat),
356 "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
357 ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
358 " .onnx")
David Beckf0b48452018-10-19 15:20:56 +0100359 ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100360 backendsMessage.c_str())
telsoa01c577f2c2018-08-31 09:22:23 +0100361 ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
362 ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
363 "executed. Defaults to 0")
364 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
365 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
366 "This parameter is optional, depending on the network.")
367 ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
368 "Path to a file containing the input data as a flat array separated by whitespace.")
Nina Drozd549ae372018-09-10 14:26:44 +0100369 ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.");
telsoa01c577f2c2018-08-31 09:22:23 +0100370 }
371 catch (const std::exception& e)
372 {
373 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
374 // and that desc.add_options() can throw boost::io::too_few_args.
375 // They really won't in any of these cases.
376 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
377 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
378 return EXIT_FAILURE;
379 }
380
381 std::vector<const char*> clOptions;
382 clOptions.reserve(csvRow.values.size());
383 for (const std::string& value : csvRow.values)
384 {
385 clOptions.push_back(value.c_str());
386 }
387
388 po::variables_map vm;
389 try
390 {
391 po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);
392
393 po::notify(vm);
394
395 CheckOptionDependencies(vm);
396 }
397 catch (const po::error& e)
398 {
399 std::cerr << e.what() << std::endl << std::endl;
400 std::cerr << desc << std::endl;
401 return EXIT_FAILURE;
402 }
403
404 // Remove leading and trailing whitespaces from the parsed arguments.
405 boost::trim(modelFormat);
406 boost::trim(modelPath);
407 boost::trim(inputName);
408 boost::trim(inputTensorShapeStr);
409 boost::trim(inputTensorDataFilePath);
410 boost::trim(outputName);
411
telsoa01c577f2c2018-08-31 09:22:23 +0100412 // Get the preferred order of compute devices.
David Beckf0b48452018-10-19 15:20:56 +0100413 std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();
telsoa01c577f2c2018-08-31 09:22:23 +0100414
415 // Remove duplicates from the list of compute devices.
416 RemoveDuplicateDevices(computeDevices);
417
418 // Check that the specified compute devices are valid.
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100419 std::string invalidBackends;
420 if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
telsoa01c577f2c2018-08-31 09:22:23 +0100421 {
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100422 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
423 << invalidBackends;
telsoa01c577f2c2018-08-31 09:22:23 +0100424 return EXIT_FAILURE;
425 }
426
427 return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
428 modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime);
429}
430
431int main(int argc, const char* argv[])
432{
433 // Configures logging for both the ARMNN library and this test program.
434#ifdef NDEBUG
435 armnn::LogSeverity level = armnn::LogSeverity::Info;
436#else
437 armnn::LogSeverity level = armnn::LogSeverity::Debug;
438#endif
439 armnn::ConfigureLogging(true, true, level);
440 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
441
442 std::string testCasesFile;
443
444 std::string modelFormat;
445 std::string modelPath;
446 std::string inputName;
447 std::string inputTensorShapeStr;
448 std::string inputTensorDataFilePath;
449 std::string outputName;
450
451 size_t subgraphId = 0;
452
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100453 const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
454 + armnn::BackendRegistryInstance().GetBackendIdsAsString();
455
telsoa01c577f2c2018-08-31 09:22:23 +0100456 po::options_description desc("Options");
457 try
458 {
459 desc.add_options()
460 ("help", "Display usage information")
461 ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. "
462 "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, "
463 "as they are expected to be defined in the file for each test in particular.")
464 ("concurrent,n", po::bool_switch()->default_value(false),
465 "Whether or not the test cases should be executed in parallel")
466 ("model-format,f", po::value(&modelFormat),
467 "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
468 ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt,"
469 " .tflite, .onnx")
David Beckf0b48452018-10-19 15:20:56 +0100470 ("compute,c", po::value<std::vector<std::string>>()->multitoken(),
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100471 backendsMessage.c_str())
telsoa01c577f2c2018-08-31 09:22:23 +0100472 ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
473 ("subgraph-number,x", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be executed."
474 "Defaults to 0")
475 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
476 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
477 "This parameter is optional, depending on the network.")
478 ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
479 "Path to a file containing the input data as a flat array separated by whitespace.")
480 ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.")
481 ("event-based-profiling,e", po::bool_switch()->default_value(false),
482 "Enables built in profiler. If unset, defaults to off.");
483 }
484 catch (const std::exception& e)
485 {
486 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
487 // and that desc.add_options() can throw boost::io::too_few_args.
488 // They really won't in any of these cases.
489 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
490 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
491 return EXIT_FAILURE;
492 }
493
494 // Parses the command-line.
495 po::variables_map vm;
496 try
497 {
498 po::store(po::parse_command_line(argc, argv, desc), vm);
499
500 if (CheckOption(vm, "help") || argc <= 1)
501 {
502 std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
503 std::cout << "Prints the resulting output tensor." << std::endl;
504 std::cout << std::endl;
505 std::cout << desc << std::endl;
506 return EXIT_SUCCESS;
507 }
508
509 po::notify(vm);
510 }
511 catch (const po::error& e)
512 {
513 std::cerr << e.what() << std::endl << std::endl;
514 std::cerr << desc << std::endl;
515 return EXIT_FAILURE;
516 }
517
518 // Get the value of the switch arguments.
519 bool concurrent = vm["concurrent"].as<bool>();
520 bool enableProfiling = vm["event-based-profiling"].as<bool>();
521
522 // Check whether we have to load test cases from a file.
523 if (CheckOption(vm, "test-cases"))
524 {
525 // Check that the file exists.
526 if (!boost::filesystem::exists(testCasesFile))
527 {
528 BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist";
529 return EXIT_FAILURE;
530 }
531
532 // Parse CSV file and extract test cases
533 armnnUtils::CsvReader reader;
534 std::vector<armnnUtils::CsvRow> testCases = reader.ParseFile(testCasesFile);
535
536 // Check that there is at least one test case to run
537 if (testCases.empty())
538 {
539 BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases";
540 return EXIT_FAILURE;
541 }
542
543 // Create runtime
544 armnn::IRuntime::CreationOptions options;
Nina Drozd549ae372018-09-10 14:26:44 +0100545 options.m_EnableGpuProfiling = enableProfiling;
546
telsoa01c577f2c2018-08-31 09:22:23 +0100547 std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(options));
548
549 const std::string executableName("ExecuteNetwork");
550
551 // Check whether we need to run the test cases concurrently
552 if (concurrent)
553 {
554 std::vector<std::future<int>> results;
555 results.reserve(testCases.size());
556
557 // Run each test case in its own thread
558 for (auto& testCase : testCases)
559 {
560 testCase.values.insert(testCase.values.begin(), executableName);
Nina Drozd549ae372018-09-10 14:26:44 +0100561 results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime),
562 enableProfiling));
telsoa01c577f2c2018-08-31 09:22:23 +0100563 }
564
565 // Check results
566 for (auto& result : results)
567 {
568 if (result.get() != EXIT_SUCCESS)
569 {
570 return EXIT_FAILURE;
571 }
572 }
573 }
574 else
575 {
576 // Run tests sequentially
577 for (auto& testCase : testCases)
578 {
579 testCase.values.insert(testCase.values.begin(), executableName);
Nina Drozd549ae372018-09-10 14:26:44 +0100580 if (RunCsvTest(testCase, runtime, enableProfiling) != EXIT_SUCCESS)
telsoa01c577f2c2018-08-31 09:22:23 +0100581 {
582 return EXIT_FAILURE;
583 }
584 }
585 }
586
587 return EXIT_SUCCESS;
588 }
589 else // Run single test
590 {
591 // Get the preferred order of compute devices.
Matteo Martincigh067112f2018-10-29 11:01:09 +0000592 std::vector<std::string> computeDevicesAsStrings = vm["compute"].as<std::vector<std::string>>();
593 std::vector<armnn::BackendId> computeDevices(computeDevicesAsStrings.begin(), computeDevicesAsStrings.end());
telsoa01c577f2c2018-08-31 09:22:23 +0100594
595 // Remove duplicates from the list of compute devices.
596 RemoveDuplicateDevices(computeDevices);
597
598 // Check that the specified compute devices are valid.
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100599 std::string invalidBackends;
600 if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
telsoa01c577f2c2018-08-31 09:22:23 +0100601 {
Aron Virginas-Tar5cc8e562018-10-23 15:14:46 +0100602 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
603 << invalidBackends;
telsoa01c577f2c2018-08-31 09:22:23 +0100604 return EXIT_FAILURE;
605 }
606
607 try
608 {
609 CheckOptionDependencies(vm);
610 }
611 catch (const po::error& e)
612 {
613 std::cerr << e.what() << std::endl << std::endl;
614 std::cerr << desc << std::endl;
615 return EXIT_FAILURE;
616 }
617
618 return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
619 modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId);
telsoa014fcda012018-03-09 14:13:49 +0000620 }
621}