blob: ee207472d00c00db507dc261a9da7145f53273e7 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "armnn/ArmNN.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01006
7#include <armnn/TypesUtils.hpp>
8
telsoa014fcda012018-03-09 14:13:49 +00009#if defined(ARMNN_CAFFE_PARSER)
10#include "armnnCaffeParser/ICaffeParser.hpp"
11#endif
surmeh01bceff2f2018-03-29 16:29:27 +010012#if defined(ARMNN_TF_PARSER)
13#include "armnnTfParser/ITfParser.hpp"
14#endif
telsoa01c577f2c2018-08-31 09:22:23 +010015#if defined(ARMNN_TF_LITE_PARSER)
16#include "armnnTfLiteParser/ITfLiteParser.hpp"
17#endif
18#if defined(ARMNN_ONNX_PARSER)
19#include "armnnOnnxParser/IOnnxParser.hpp"
20#endif
21#include "CsvReader.hpp"
telsoa014fcda012018-03-09 14:13:49 +000022#include "../InferenceTest.hpp"
23
telsoa01c577f2c2018-08-31 09:22:23 +010024#include <Logging.hpp>
25#include <Profiling.hpp>
26
27#include <boost/algorithm/string/trim.hpp>
telsoa014fcda012018-03-09 14:13:49 +000028#include <boost/algorithm/string/split.hpp>
29#include <boost/algorithm/string/classification.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010030#include <boost/program_options.hpp>
telsoa014fcda012018-03-09 14:13:49 +000031
32#include <iostream>
33#include <fstream>
telsoa01c577f2c2018-08-31 09:22:23 +010034#include <functional>
35#include <future>
36#include <algorithm>
37#include <iterator>
telsoa014fcda012018-03-09 14:13:49 +000038
39namespace
40{
41
telsoa01c577f2c2018-08-31 09:22:23 +010042// Configure boost::program_options for command-line parsing and validation.
43namespace po = boost::program_options;
44
telsoa014fcda012018-03-09 14:13:49 +000045template<typename T, typename TParseElementFunc>
46std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
47{
48 std::vector<T> result;
telsoa01c577f2c2018-08-31 09:22:23 +010049 // Processes line-by-line.
telsoa014fcda012018-03-09 14:13:49 +000050 std::string line;
51 while (std::getline(stream, line))
52 {
53 std::vector<std::string> tokens;
surmeh013537c2c2018-05-18 16:31:43 +010054 try
55 {
56 // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
57 boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
58 }
59 catch (const std::exception& e)
60 {
61 BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
62 continue;
63 }
telsoa014fcda012018-03-09 14:13:49 +000064 for (const std::string& token : tokens)
65 {
66 if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
67 {
68 try
69 {
70 result.push_back(parseElementFunc(token));
71 }
72 catch (const std::exception&)
73 {
74 BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
75 }
76 }
77 }
78 }
79
80 return result;
81}
82
telsoa01c577f2c2018-08-31 09:22:23 +010083bool CheckOption(const po::variables_map& vm,
84 const char* option)
85{
86 // Check that the given option is valid.
87 if (option == nullptr)
88 {
89 return false;
90 }
91
92 // Check whether 'option' is provided.
93 return vm.find(option) != vm.end();
94}
95
96void CheckOptionDependency(const po::variables_map& vm,
97 const char* option,
98 const char* required)
99{
100 // Check that the given options are valid.
101 if (option == nullptr || required == nullptr)
102 {
103 throw po::error("Invalid option to check dependency for");
104 }
105
106 // Check that if 'option' is provided, 'required' is also provided.
107 if (CheckOption(vm, option) && !vm[option].defaulted())
108 {
109 if (CheckOption(vm, required) == 0 || vm[required].defaulted())
110 {
111 throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
112 }
113 }
114}
115
116void CheckOptionDependencies(const po::variables_map& vm)
117{
118 CheckOptionDependency(vm, "model-path", "model-format");
119 CheckOptionDependency(vm, "model-path", "input-name");
120 CheckOptionDependency(vm, "model-path", "input-tensor-data");
121 CheckOptionDependency(vm, "model-path", "output-name");
122 CheckOptionDependency(vm, "input-tensor-shape", "model-path");
telsoa014fcda012018-03-09 14:13:49 +0000123}
124
/// Parses a flat array of values of type T from a text stream (tokenized by ParseArrayImpl).
/// Only the explicit specializations defined below (float, unsigned int) have definitions.
template <class T>
std::vector<T> ParseArray(std::istream& stream);
127
128template<>
129std::vector<float> ParseArray(std::istream& stream)
130{
131 return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
132}
133
134template<>
135std::vector<unsigned int> ParseArray(std::istream& stream)
136{
137 return ParseArrayImpl<unsigned int>(stream,
138 [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
139}
140
/// Writes every element of @p v to stdout formatted with "%f " and terminates the line.
void PrintArray(const std::vector<float>& v)
{
    for (const float value : v)
    {
        printf("%f ", value);
    }
    printf("\n");
}
149
telsoa01c577f2c2018-08-31 09:22:23 +0100150void RemoveDuplicateDevices(std::vector<armnn::Compute>& computeDevices)
telsoa014fcda012018-03-09 14:13:49 +0000151{
telsoa01c577f2c2018-08-31 09:22:23 +0100152 // Mark the duplicate devices as 'Undefined'.
153 for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
154 {
155 for (auto j = std::next(i); j != computeDevices.end(); ++j)
156 {
157 if (*j == *i)
158 {
159 *j = armnn::Compute::Undefined;
160 }
161 }
162 }
163
164 // Remove 'Undefined' devices.
165 computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
166 computeDevices.end());
167}
168
169bool CheckDevicesAreValid(const std::vector<armnn::Compute>& computeDevices)
170{
171 return (!computeDevices.empty()
172 && std::none_of(computeDevices.begin(), computeDevices.end(),
173 [](armnn::Compute c){ return c == armnn::Compute::Undefined; }));
174}
175
176} // namespace
177
178template<typename TParser, typename TDataType>
179int MainImpl(const char* modelPath,
180 bool isModelBinary,
181 const std::vector<armnn::Compute>& computeDevice,
182 const char* inputName,
183 const armnn::TensorShape* inputTensorShape,
184 const char* inputTensorDataFilePath,
185 const char* outputName,
186 bool enableProfiling,
187 const size_t subgraphId,
188 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
189{
190 // Loads input tensor.
telsoa014fcda012018-03-09 14:13:49 +0000191 std::vector<TDataType> input;
192 {
193 std::ifstream inputTensorFile(inputTensorDataFilePath);
194 if (!inputTensorFile.good())
195 {
196 BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
telsoa01c577f2c2018-08-31 09:22:23 +0100197 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000198 }
199 input = ParseArray<TDataType>(inputTensorFile);
200 }
201
202 try
203 {
telsoa01c577f2c2018-08-31 09:22:23 +0100204 // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
telsoa014fcda012018-03-09 14:13:49 +0000205 typename InferenceModel<TParser, TDataType>::Params params;
206 params.m_ModelPath = modelPath;
207 params.m_IsModelBinary = isModelBinary;
208 params.m_ComputeDevice = computeDevice;
209 params.m_InputBinding = inputName;
210 params.m_InputTensorShape = inputTensorShape;
211 params.m_OutputBinding = outputName;
telsoa01c577f2c2018-08-31 09:22:23 +0100212 params.m_EnableProfiling = enableProfiling;
213 params.m_SubgraphId = subgraphId;
214 InferenceModel<TParser, TDataType> model(params, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000215
telsoa01c577f2c2018-08-31 09:22:23 +0100216 // Executes the model.
telsoa014fcda012018-03-09 14:13:49 +0000217 std::vector<TDataType> output(model.GetOutputSize());
218 model.Run(input, output);
219
telsoa01c577f2c2018-08-31 09:22:23 +0100220 // Prints the output tensor.
telsoa014fcda012018-03-09 14:13:49 +0000221 PrintArray(output);
222 }
223 catch (armnn::Exception const& e)
224 {
225 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100226 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000227 }
228
telsoa01c577f2c2018-08-31 09:22:23 +0100229 return EXIT_SUCCESS;
telsoa014fcda012018-03-09 14:13:49 +0000230}
231
telsoa01c577f2c2018-08-31 09:22:23 +0100232// This will run a test
233int RunTest(const std::string& modelFormat,
234 const std::string& inputTensorShapeStr,
235 const vector<armnn::Compute>& computeDevice,
236 const std::string& modelPath,
237 const std::string& inputName,
238 const std::string& inputTensorDataFilePath,
239 const std::string& outputName,
240 bool enableProfiling,
241 const size_t subgraphId,
242 const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
telsoa014fcda012018-03-09 14:13:49 +0000243{
telsoa014fcda012018-03-09 14:13:49 +0000244 // Parse model binary flag from the model-format string we got from the command-line
245 bool isModelBinary;
246 if (modelFormat.find("bin") != std::string::npos)
247 {
248 isModelBinary = true;
249 }
250 else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
251 {
252 isModelBinary = false;
253 }
254 else
255 {
256 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
telsoa01c577f2c2018-08-31 09:22:23 +0100257 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000258 }
259
260 // Parse input tensor shape from the string we got from the command-line.
261 std::unique_ptr<armnn::TensorShape> inputTensorShape;
262 if (!inputTensorShapeStr.empty())
263 {
264 std::stringstream ss(inputTensorShapeStr);
265 std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);
surmeh013537c2c2018-05-18 16:31:43 +0100266
267 try
268 {
269 // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
270 inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
271 }
272 catch (const armnn::InvalidArgumentException& e)
273 {
274 BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
telsoa01c577f2c2018-08-31 09:22:23 +0100275 return EXIT_FAILURE;
surmeh013537c2c2018-05-18 16:31:43 +0100276 }
telsoa014fcda012018-03-09 14:13:49 +0000277 }
278
279 // Forward to implementation based on the parser type
280 if (modelFormat.find("caffe") != std::string::npos)
281 {
282#if defined(ARMNN_CAFFE_PARSER)
283 return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100284 inputName.c_str(), inputTensorShape.get(),
285 inputTensorDataFilePath.c_str(), outputName.c_str(),
286 enableProfiling, subgraphId, runtime);
telsoa014fcda012018-03-09 14:13:49 +0000287#else
288 BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100289 return EXIT_FAILURE;
290#endif
291 }
292 else if (modelFormat.find("onnx") != std::string::npos)
293{
294#if defined(ARMNN_ONNX_PARSER)
295 return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
296 inputName.c_str(), inputTensorShape.get(),
297 inputTensorDataFilePath.c_str(), outputName.c_str(),
298 enableProfiling, subgraphId, runtime);
299#else
300 BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
301 return EXIT_FAILURE;
telsoa014fcda012018-03-09 14:13:49 +0000302#endif
303 }
304 else if (modelFormat.find("tensorflow") != std::string::npos)
305 {
surmeh01bceff2f2018-03-29 16:29:27 +0100306#if defined(ARMNN_TF_PARSER)
307 return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
telsoa01c577f2c2018-08-31 09:22:23 +0100308 inputName.c_str(), inputTensorShape.get(),
309 inputTensorDataFilePath.c_str(), outputName.c_str(),
310 enableProfiling, subgraphId, runtime);
surmeh01bceff2f2018-03-29 16:29:27 +0100311#else
telsoa014fcda012018-03-09 14:13:49 +0000312 BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
telsoa01c577f2c2018-08-31 09:22:23 +0100313 return EXIT_FAILURE;
314#endif
315 }
316 else if(modelFormat.find("tflite") != std::string::npos)
317 {
318#if defined(ARMNN_TF_LITE_PARSER)
319 if (! isModelBinary)
320 {
321 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Only 'binary' format supported \
322 for tflite files";
323 return EXIT_FAILURE;
324 }
325 return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
326 inputName.c_str(), inputTensorShape.get(),
327 inputTensorDataFilePath.c_str(), outputName.c_str(),
328 enableProfiling, subgraphId, runtime);
329#else
330 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
331 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
332 return EXIT_FAILURE;
surmeh01bceff2f2018-03-29 16:29:27 +0100333#endif
telsoa014fcda012018-03-09 14:13:49 +0000334 }
335 else
336 {
337 BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
telsoa01c577f2c2018-08-31 09:22:23 +0100338 "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
339 return EXIT_FAILURE;
340 }
341}
342
343int RunCsvTest(const armnnUtils::CsvRow &csvRow,
Nina Drozd549ae372018-09-10 14:26:44 +0100344 const std::shared_ptr<armnn::IRuntime>& runtime, const bool enableProfiling)
telsoa01c577f2c2018-08-31 09:22:23 +0100345{
346 std::string modelFormat;
347 std::string modelPath;
348 std::string inputName;
349 std::string inputTensorShapeStr;
350 std::string inputTensorDataFilePath;
351 std::string outputName;
352
353 size_t subgraphId = 0;
354
355 po::options_description desc("Options");
356 try
357 {
358 desc.add_options()
359 ("model-format,f", po::value(&modelFormat),
360 "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
361 ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
362 " .onnx")
363 ("compute,c", po::value<std::vector<armnn::Compute>>()->multitoken(),
364 "The preferred order of devices to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
365 ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
366 ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
367 "executed. Defaults to 0")
368 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
369 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
370 "This parameter is optional, depending on the network.")
371 ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
372 "Path to a file containing the input data as a flat array separated by whitespace.")
Nina Drozd549ae372018-09-10 14:26:44 +0100373 ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.");
telsoa01c577f2c2018-08-31 09:22:23 +0100374 }
375 catch (const std::exception& e)
376 {
377 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
378 // and that desc.add_options() can throw boost::io::too_few_args.
379 // They really won't in any of these cases.
380 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
381 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
382 return EXIT_FAILURE;
383 }
384
385 std::vector<const char*> clOptions;
386 clOptions.reserve(csvRow.values.size());
387 for (const std::string& value : csvRow.values)
388 {
389 clOptions.push_back(value.c_str());
390 }
391
392 po::variables_map vm;
393 try
394 {
395 po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);
396
397 po::notify(vm);
398
399 CheckOptionDependencies(vm);
400 }
401 catch (const po::error& e)
402 {
403 std::cerr << e.what() << std::endl << std::endl;
404 std::cerr << desc << std::endl;
405 return EXIT_FAILURE;
406 }
407
408 // Remove leading and trailing whitespaces from the parsed arguments.
409 boost::trim(modelFormat);
410 boost::trim(modelPath);
411 boost::trim(inputName);
412 boost::trim(inputTensorShapeStr);
413 boost::trim(inputTensorDataFilePath);
414 boost::trim(outputName);
415
telsoa01c577f2c2018-08-31 09:22:23 +0100416 // Get the preferred order of compute devices.
417 std::vector<armnn::Compute> computeDevices = vm["compute"].as<std::vector<armnn::Compute>>();
418
419 // Remove duplicates from the list of compute devices.
420 RemoveDuplicateDevices(computeDevices);
421
422 // Check that the specified compute devices are valid.
423 if (!CheckDevicesAreValid(computeDevices))
424 {
425 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains an invalid compute";
426 return EXIT_FAILURE;
427 }
428
429 return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
430 modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime);
431}
432
433int main(int argc, const char* argv[])
434{
435 // Configures logging for both the ARMNN library and this test program.
436#ifdef NDEBUG
437 armnn::LogSeverity level = armnn::LogSeverity::Info;
438#else
439 armnn::LogSeverity level = armnn::LogSeverity::Debug;
440#endif
441 armnn::ConfigureLogging(true, true, level);
442 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
443
444 std::string testCasesFile;
445
446 std::string modelFormat;
447 std::string modelPath;
448 std::string inputName;
449 std::string inputTensorShapeStr;
450 std::string inputTensorDataFilePath;
451 std::string outputName;
452
453 size_t subgraphId = 0;
454
455 po::options_description desc("Options");
456 try
457 {
458 desc.add_options()
459 ("help", "Display usage information")
460 ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. "
461 "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, "
462 "as they are expected to be defined in the file for each test in particular.")
463 ("concurrent,n", po::bool_switch()->default_value(false),
464 "Whether or not the test cases should be executed in parallel")
465 ("model-format,f", po::value(&modelFormat),
466 "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
467 ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt,"
468 " .tflite, .onnx")
469 ("compute,c", po::value<std::vector<armnn::Compute>>()->multitoken(),
470 "The preferred order of devices to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
471 ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
472 ("subgraph-number,x", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be executed."
473 "Defaults to 0")
474 ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
475 "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
476 "This parameter is optional, depending on the network.")
477 ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
478 "Path to a file containing the input data as a flat array separated by whitespace.")
479 ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.")
480 ("event-based-profiling,e", po::bool_switch()->default_value(false),
481 "Enables built in profiler. If unset, defaults to off.");
482 }
483 catch (const std::exception& e)
484 {
485 // Coverity points out that default_value(...) can throw a bad_lexical_cast,
486 // and that desc.add_options() can throw boost::io::too_few_args.
487 // They really won't in any of these cases.
488 BOOST_ASSERT_MSG(false, "Caught unexpected exception");
489 BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
490 return EXIT_FAILURE;
491 }
492
493 // Parses the command-line.
494 po::variables_map vm;
495 try
496 {
497 po::store(po::parse_command_line(argc, argv, desc), vm);
498
499 if (CheckOption(vm, "help") || argc <= 1)
500 {
501 std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
502 std::cout << "Prints the resulting output tensor." << std::endl;
503 std::cout << std::endl;
504 std::cout << desc << std::endl;
505 return EXIT_SUCCESS;
506 }
507
508 po::notify(vm);
509 }
510 catch (const po::error& e)
511 {
512 std::cerr << e.what() << std::endl << std::endl;
513 std::cerr << desc << std::endl;
514 return EXIT_FAILURE;
515 }
516
517 // Get the value of the switch arguments.
518 bool concurrent = vm["concurrent"].as<bool>();
519 bool enableProfiling = vm["event-based-profiling"].as<bool>();
520
521 // Check whether we have to load test cases from a file.
522 if (CheckOption(vm, "test-cases"))
523 {
524 // Check that the file exists.
525 if (!boost::filesystem::exists(testCasesFile))
526 {
527 BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist";
528 return EXIT_FAILURE;
529 }
530
531 // Parse CSV file and extract test cases
532 armnnUtils::CsvReader reader;
533 std::vector<armnnUtils::CsvRow> testCases = reader.ParseFile(testCasesFile);
534
535 // Check that there is at least one test case to run
536 if (testCases.empty())
537 {
538 BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases";
539 return EXIT_FAILURE;
540 }
541
542 // Create runtime
543 armnn::IRuntime::CreationOptions options;
Nina Drozd549ae372018-09-10 14:26:44 +0100544 options.m_EnableGpuProfiling = enableProfiling;
545
telsoa01c577f2c2018-08-31 09:22:23 +0100546 std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(options));
547
548 const std::string executableName("ExecuteNetwork");
549
550 // Check whether we need to run the test cases concurrently
551 if (concurrent)
552 {
553 std::vector<std::future<int>> results;
554 results.reserve(testCases.size());
555
556 // Run each test case in its own thread
557 for (auto& testCase : testCases)
558 {
559 testCase.values.insert(testCase.values.begin(), executableName);
Nina Drozd549ae372018-09-10 14:26:44 +0100560 results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime),
561 enableProfiling));
telsoa01c577f2c2018-08-31 09:22:23 +0100562 }
563
564 // Check results
565 for (auto& result : results)
566 {
567 if (result.get() != EXIT_SUCCESS)
568 {
569 return EXIT_FAILURE;
570 }
571 }
572 }
573 else
574 {
575 // Run tests sequentially
576 for (auto& testCase : testCases)
577 {
578 testCase.values.insert(testCase.values.begin(), executableName);
Nina Drozd549ae372018-09-10 14:26:44 +0100579 if (RunCsvTest(testCase, runtime, enableProfiling) != EXIT_SUCCESS)
telsoa01c577f2c2018-08-31 09:22:23 +0100580 {
581 return EXIT_FAILURE;
582 }
583 }
584 }
585
586 return EXIT_SUCCESS;
587 }
588 else // Run single test
589 {
590 // Get the preferred order of compute devices.
591 std::vector<armnn::Compute> computeDevices = vm["compute"].as<std::vector<armnn::Compute>>();
592
593 // Remove duplicates from the list of compute devices.
594 RemoveDuplicateDevices(computeDevices);
595
596 // Check that the specified compute devices are valid.
597 if (!CheckDevicesAreValid(computeDevices))
598 {
599 BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains an invalid compute";
600 return EXIT_FAILURE;
601 }
602
603 try
604 {
605 CheckOptionDependencies(vm);
606 }
607 catch (const po::error& e)
608 {
609 std::cerr << e.what() << std::endl << std::endl;
610 std::cerr << desc << std::endl;
611 return EXIT_FAILURE;
612 }
613
614 return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
615 modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId);
telsoa014fcda012018-03-09 14:13:49 +0000616 }
617}