//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/ArmNN.hpp>
#include <armnn/TypesUtils.hpp>

#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#if defined(ARMNN_TF_PARSER)
#include "armnnTfParser/ITfParser.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif
#include "CsvReader.hpp"
#include "../InferenceTest.hpp"

#include <Logging.hpp>
#include <Profiling.hpp>

#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/classification.hpp>
#include <boost/program_options.hpp>

#include <iostream>
#include <fstream>
#include <sstream>
#include <functional>
#include <future>
#include <algorithm>
#include <iterator>

namespace
{

// boost::program_options is used for command-line parsing and validation.
namespace po = boost::program_options;

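// ParseArrayImpl reads numeric values from a stream, splitting each line on tabs, spaces,
// commas, semicolons and colons. For instance, a line of "1, 2; 3 4" yields {1, 2, 3, 4}.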
template<typename T, typename TParseElementFunc>
std::vector<T> ParseArrayImpl(std::istream& stream, TParseElementFunc parseElementFunc)
{
    std::vector<T> result;
    // Processes line-by-line.
    std::string line;
    while (std::getline(stream, line))
    {
        std::vector<std::string> tokens;
        try
        {
            // Coverity fix: boost::split() may throw an exception of type boost::bad_function_call.
            boost::split(tokens, line, boost::algorithm::is_any_of("\t ,;:"), boost::token_compress_on);
        }
        catch (const std::exception& e)
        {
            BOOST_LOG_TRIVIAL(error) << "An error occurred when splitting tokens: " << e.what();
            continue;
        }
        for (const std::string& token : tokens)
        {
            if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
            {
                try
                {
                    result.push_back(parseElementFunc(token));
                }
                catch (const std::exception&)
                {
                    BOOST_LOG_TRIVIAL(error) << "'" << token << "' is not a valid number. It has been ignored.";
                }
            }
        }
    }

    return result;
}

bool CheckOption(const po::variables_map& vm,
                 const char* option)
{
    // Check that the given option is valid.
    if (option == nullptr)
    {
        return false;
    }

    // Check whether 'option' is provided.
    return vm.find(option) != vm.end();
}

void CheckOptionDependency(const po::variables_map& vm,
                           const char* option,
                           const char* required)
{
    // Check that the given options are valid.
    if (option == nullptr || required == nullptr)
    {
        throw po::error("Invalid option to check dependency for");
    }

    // Check that if 'option' is provided, 'required' is also provided.
    if (CheckOption(vm, option) && !vm[option].defaulted())
    {
        if (!CheckOption(vm, required) || vm[required].defaulted())
        {
            throw po::error(std::string("Option '") + option + "' requires option '" + required + "'.");
        }
    }
}

void CheckOptionDependencies(const po::variables_map& vm)
{
    CheckOptionDependency(vm, "model-path", "model-format");
    CheckOptionDependency(vm, "model-path", "input-name");
    CheckOptionDependency(vm, "model-path", "input-tensor-data");
    CheckOptionDependency(vm, "model-path", "output-name");
    CheckOptionDependency(vm, "input-tensor-shape", "model-path");
}

template<typename T>
std::vector<T> ParseArray(std::istream& stream);

template<>
std::vector<float> ParseArray(std::istream& stream)
{
    return ParseArrayImpl<float>(stream, [](const std::string& s) { return std::stof(s); });
}

template<>
std::vector<unsigned int> ParseArray(std::istream& stream)
{
    return ParseArrayImpl<unsigned int>(stream,
        [](const std::string& s) { return boost::numeric_cast<unsigned int>(std::stoi(s)); });
}

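// A minimal usage sketch (illustrative), e.g. for a shape string passed on the command-line:
//   std::stringstream ss("1 3 224 224");
//   std::vector<unsigned int> dims = ParseArray<unsigned int>(ss); // dims == {1, 3, 224, 224}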
void PrintOutputData(const std::string& outputLayerName, const std::vector<float>& data)
{
    std::cout << outputLayerName << ": ";
    std::copy(data.begin(), data.end(),
              std::ostream_iterator<float>(std::cout, " "));
    std::cout << std::endl;
}

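// Removes duplicate devices while preserving the order of first occurrence; for example,
// {CpuAcc, GpuAcc, CpuAcc} is reduced to {CpuAcc, GpuAcc}.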
void RemoveDuplicateDevices(std::vector<armnn::BackendId>& computeDevices)
{
    // Mark the duplicate devices as 'Undefined'.
    for (auto i = computeDevices.begin(); i != computeDevices.end(); ++i)
    {
        for (auto j = std::next(i); j != computeDevices.end(); ++j)
        {
            if (*j == *i)
            {
                *j = armnn::Compute::Undefined;
            }
        }
    }

    // Remove 'Undefined' devices.
    computeDevices.erase(std::remove(computeDevices.begin(), computeDevices.end(), armnn::Compute::Undefined),
                         computeDevices.end());
}

} // namespace

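// Loads the input tensor data from file, builds an InferenceModel for the given parser type,
// runs a single inference and prints each output tensor to stdout.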
template<typename TParser, typename TDataType>
int MainImpl(const char* modelPath,
             bool isModelBinary,
             const std::vector<armnn::BackendId>& computeDevice,
             const char* inputName,
             const armnn::TensorShape* inputTensorShape,
             const char* inputTensorDataFilePath,
             const char* outputName,
             bool enableProfiling,
             const size_t subgraphId,
             const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    using TContainer = std::vector<TDataType>;

    // Loads the input tensor data.
    TContainer inputDataContainer;
    {
        std::ifstream inputTensorFile(inputTensorDataFilePath);
        if (!inputTensorFile.good())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Failed to load input tensor data file from " << inputTensorDataFilePath;
            return EXIT_FAILURE;
        }
        inputDataContainer = ParseArray<TDataType>(inputTensorFile);
    }

    try
    {
        // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
        typename InferenceModel<TParser, TDataType>::Params params;
        params.m_ModelPath = modelPath;
        params.m_IsModelBinary = isModelBinary;
        params.m_ComputeDevice = computeDevice;
        params.m_InputBindings = { inputName };
        params.m_InputShapes = { *inputTensorShape };
        params.m_OutputBindings = { outputName };
        params.m_EnableProfiling = enableProfiling;
        params.m_SubgraphId = subgraphId;
        InferenceModel<TParser, TDataType> model(params, runtime);

        const size_t numOutputs = params.m_OutputBindings.size();
        const size_t containerSize = model.GetOutputSize();

        std::vector<TContainer> outputDataContainers(numOutputs, TContainer(containerSize));

        // Execute the model.
        model.Run({ inputDataContainer }, outputDataContainers);

        // Print the output tensors.
        for (size_t i = 0; i < numOutputs; i++)
        {
            PrintOutputData(params.m_OutputBindings[i], outputDataContainers[i]);
        }
    }
    catch (armnn::Exception const& e)
    {
        BOOST_LOG_TRIVIAL(fatal) << "ArmNN Error: " << e.what();
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

// Runs a single test: dispatches to MainImpl with the parser type that matches the given model format.
int RunTest(const std::string& modelFormat,
            const std::string& inputTensorShapeStr,
            const std::vector<armnn::BackendId>& computeDevice,
            const std::string& modelPath,
            const std::string& inputName,
            const std::string& inputTensorDataFilePath,
            const std::string& outputName,
            bool enableProfiling,
            const size_t subgraphId,
            const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    // Parse the model binary flag from the model-format string we got from the command-line.
    bool isModelBinary;
    if (modelFormat.find("bin") != std::string::npos)
    {
        isModelBinary = true;
    }
    else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
    {
        isModelBinary = false;
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat << "'. Please include 'binary' or 'text'";
        return EXIT_FAILURE;
    }

    // Parse the input tensor shape from the string we got from the command-line.
    std::unique_ptr<armnn::TensorShape> inputTensorShape;
    if (!inputTensorShapeStr.empty())
    {
        std::stringstream ss(inputTensorShapeStr);
        std::vector<unsigned int> dims = ParseArray<unsigned int>(ss);

        try
        {
            // Coverity fix: An exception of type armnn::InvalidArgumentException is thrown and never caught.
            inputTensorShape = std::make_unique<armnn::TensorShape>(dims.size(), dims.data());
        }
        catch (const armnn::InvalidArgumentException& e)
        {
            BOOST_LOG_TRIVIAL(fatal) << "Cannot create tensor shape: " << e.what();
            return EXIT_FAILURE;
        }
    }

    // Forward to the implementation based on the parser type.
    if (modelFormat.find("caffe") != std::string::npos)
    {
#if defined(ARMNN_CAFFE_PARSER)
        return MainImpl<armnnCaffeParser::ICaffeParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                               inputName.c_str(), inputTensorShape.get(),
                                                               inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                               enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Caffe parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("onnx") != std::string::npos)
    {
#if defined(ARMNN_ONNX_PARSER)
        return MainImpl<armnnOnnxParser::IOnnxParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                             inputName.c_str(), inputTensorShape.get(),
                                                             inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                             enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Onnx parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tensorflow") != std::string::npos)
    {
#if defined(ARMNN_TF_PARSER)
        return MainImpl<armnnTfParser::ITfParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                         inputName.c_str(), inputTensorShape.get(),
                                                         inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                         enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with Tensorflow parser support.";
        return EXIT_FAILURE;
#endif
    }
    else if (modelFormat.find("tflite") != std::string::npos)
    {
#if defined(ARMNN_TF_LITE_PARSER)
        if (!isModelBinary)
        {
            BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat
                                     << "'. Only 'binary' format supported for tflite files";
            return EXIT_FAILURE;
        }
        return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(modelPath.c_str(), isModelBinary, computeDevice,
                                                                 inputName.c_str(), inputTensorShape.get(),
                                                                 inputTensorDataFilePath.c_str(), outputName.c_str(),
                                                                 enableProfiling, subgraphId, runtime);
#else
        BOOST_LOG_TRIVIAL(fatal) << "Not built with TfLite parser support.";
        return EXIT_FAILURE;
#endif
    }
    else
    {
        BOOST_LOG_TRIVIAL(fatal) << "Unknown model format: '" << modelFormat <<
                                    "'. Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'";
        return EXIT_FAILURE;
    }
}

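// Runs one test case taken from a row of a CSV file. The row's comma-separated values are
// treated as command-line tokens (the caller prepends the executable name). An illustrative
// row: -f, tflite-binary, -m, model.tflite, -c, CpuAcc, -i, input, -o, output, -d, input.txt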
int RunCsvTest(const armnnUtils::CsvRow& csvRow,
               const std::shared_ptr<armnn::IRuntime>& runtime, const bool enableProfiling)
{
    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = std::string("The preferred order of devices to run layers on by default. ")
                                      + std::string("Possible choices: ")
                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("model-format,f", po::value(&modelFormat),
             "caffe-binary, caffe-text, tflite-binary, onnx-binary, onnx-text, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath), "Path to model file, e.g. .caffemodel, .prototxt, .tflite,"
             " .onnx")
            ("compute,c", po::value<std::vector<armnn::BackendId>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,n", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
             "executed. Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    std::vector<const char*> clOptions;
    clOptions.reserve(csvRow.values.size());
    for (const std::string& value : csvRow.values)
    {
        clOptions.push_back(value.c_str());
    }

    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(static_cast<int>(clOptions.size()), clOptions.data(), desc), vm);

        po::notify(vm);

        CheckOptionDependencies(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Remove leading and trailing whitespace from the parsed arguments.
    boost::trim(modelFormat);
    boost::trim(modelPath);
    boost::trim(inputName);
    boost::trim(inputTensorShapeStr);
    boost::trim(inputTensorDataFilePath);
    boost::trim(outputName);

    // Get the preferred order of compute devices.
    std::vector<armnn::BackendId> computeDevices = vm["compute"].as<std::vector<armnn::BackendId>>();

    // Remove duplicates from the list of compute devices.
    RemoveDuplicateDevices(computeDevices);

    // Check that the specified compute devices are valid.
    std::string invalidBackends;
    if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
    {
        BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                 << invalidBackends;
        return EXIT_FAILURE;
    }

    return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                   modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId, runtime);
}

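// Entry point. Illustrative invocations (paths and tensor names are examples only):
//   ExecuteNetwork -f caffe-binary -m model.caffemodel -c CpuAcc -i data -s "1 3 224 224" -d in.txt -o prob
//   ExecuteNetwork -t test-cases.csv -n    (runs the test cases from the CSV file concurrently)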
int main(int argc, const char* argv[])
{
    // Configures logging for both the ArmNN library and this test program.
#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif
    armnn::ConfigureLogging(true, true, level);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);

    std::string testCasesFile;

    std::string modelFormat;
    std::string modelPath;
    std::string inputName;
    std::string inputTensorShapeStr;
    std::string inputTensorDataFilePath;
    std::string outputName;

    size_t subgraphId = 0;

    const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();

    po::options_description desc("Options");
    try
    {
        desc.add_options()
            ("help", "Display usage information")
            ("test-cases,t", po::value(&testCasesFile), "Path to a CSV file containing test cases to run. "
             "If set, further parameters -- with the exception of compute device and concurrency -- will be ignored, "
             "as they are expected to be defined in the file for each test in particular.")
            ("concurrent,n", po::bool_switch()->default_value(false),
             "Whether or not the test cases should be executed in parallel")
            ("model-format,f", po::value(&modelFormat)->required(),
             "caffe-binary, caffe-text, onnx-binary, onnx-text, tflite-binary, tensorflow-binary or tensorflow-text.")
            ("model-path,m", po::value(&modelPath)->required(), "Path to model file, e.g. .caffemodel, .prototxt,"
             " .tflite, .onnx")
            ("compute,c", po::value<std::vector<std::string>>()->multitoken(),
             backendsMessage.c_str())
            ("input-name,i", po::value(&inputName), "Identifier of the input tensor in the network.")
            ("subgraph-number,x", po::value<size_t>(&subgraphId)->default_value(0), "Id of the subgraph to be "
             "executed. Defaults to 0")
            ("input-tensor-shape,s", po::value(&inputTensorShapeStr),
             "The shape of the input tensor in the network as a flat array of integers separated by whitespace. "
             "This parameter is optional, depending on the network.")
            ("input-tensor-data,d", po::value(&inputTensorDataFilePath),
             "Path to a file containing the input data as a flat array separated by whitespace.")
            ("output-name,o", po::value(&outputName), "Identifier of the output tensor in the network.")
            ("event-based-profiling,e", po::bool_switch()->default_value(false),
             "Enables built in profiler. If unset, defaults to off.");
    }
    catch (const std::exception& e)
    {
        // Coverity points out that default_value(...) can throw a bad_lexical_cast,
        // and that desc.add_options() can throw boost::io::too_few_args.
        // They really won't in any of these cases.
        BOOST_ASSERT_MSG(false, "Caught unexpected exception");
        BOOST_LOG_TRIVIAL(fatal) << "Fatal internal error: " << e.what();
        return EXIT_FAILURE;
    }

    // Parses the command-line.
    po::variables_map vm;
    try
    {
        po::store(po::parse_command_line(argc, argv, desc), vm);

        if (CheckOption(vm, "help") || argc <= 1)
        {
            std::cout << "Executes a neural network model using the provided input tensor. " << std::endl;
            std::cout << "Prints the resulting output tensor." << std::endl;
            std::cout << std::endl;
            std::cout << desc << std::endl;
            return EXIT_SUCCESS;
        }

        po::notify(vm);
    }
    catch (const po::error& e)
    {
        std::cerr << e.what() << std::endl << std::endl;
        std::cerr << desc << std::endl;
        return EXIT_FAILURE;
    }

    // Get the values of the switch arguments.
    bool concurrent = vm["concurrent"].as<bool>();
    bool enableProfiling = vm["event-based-profiling"].as<bool>();

    // Check whether we have to load test cases from a file.
    if (CheckOption(vm, "test-cases"))
    {
        // Check that the file exists.
        if (!boost::filesystem::exists(testCasesFile))
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" does not exist";
            return EXIT_FAILURE;
        }

        // Parse the CSV file and extract the test cases.
        armnnUtils::CsvReader reader;
        std::vector<armnnUtils::CsvRow> testCases = reader.ParseFile(testCasesFile);

        // Check that there is at least one test case to run.
        if (testCases.empty())
        {
            BOOST_LOG_TRIVIAL(fatal) << "Given file \"" << testCasesFile << "\" has no test cases";
            return EXIT_FAILURE;
        }

        // Create a runtime to be shared by all test cases.
        armnn::IRuntime::CreationOptions options;
        options.m_EnableGpuProfiling = enableProfiling;

        std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(options));

        const std::string executableName("ExecuteNetwork");

        // Check whether we need to run the test cases concurrently.
        if (concurrent)
        {
            std::vector<std::future<int>> results;
            results.reserve(testCases.size());

            // Run each test case in its own thread.
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime),
                                             enableProfiling));
            }

            // Check the results.
            for (auto& result : results)
            {
                if (result.get() != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }
        else
        {
            // Run the test cases sequentially.
            for (auto& testCase : testCases)
            {
                testCase.values.insert(testCase.values.begin(), executableName);
                if (RunCsvTest(testCase, runtime, enableProfiling) != EXIT_SUCCESS)
                {
                    return EXIT_FAILURE;
                }
            }
        }

        return EXIT_SUCCESS;
    }
    else // Run a single test.
    {
        // Get the preferred order of compute devices. If none are specified, default to using CpuRef.
        const std::string computeOption("compute");
        std::vector<std::string> computeDevicesAsStrings = CheckOption(vm, computeOption.c_str()) ?
            vm[computeOption].as<std::vector<std::string>>() :
            std::vector<std::string>({ "CpuRef" });
        std::vector<armnn::BackendId> computeDevices(computeDevicesAsStrings.begin(), computeDevicesAsStrings.end());

        // Remove duplicates from the list of compute devices.
        RemoveDuplicateDevices(computeDevices);

        // Check that the specified compute devices are valid.
        std::string invalidBackends;
        if (!CheckRequestedBackendsAreValid(computeDevices, armnn::Optional<std::string&>(invalidBackends)))
        {
            BOOST_LOG_TRIVIAL(fatal) << "The list of preferred devices contains invalid backend IDs: "
                                     << invalidBackends;
            return EXIT_FAILURE;
        }

        try
        {
            CheckOptionDependencies(vm);
        }
        catch (const po::error& e)
        {
            std::cerr << e.what() << std::endl << std::endl;
            std::cerr << desc << std::endl;
            return EXIT_FAILURE;
        }

        return RunTest(modelFormat, inputTensorShapeStr, computeDevices,
                       modelPath, inputName, inputTensorDataFilePath, outputName, enableProfiling, subgraphId);
    }
}