//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
#include "ExecuteNetworkProgramOptions.hpp"

#include <armnn/Logging.hpp>
#include <Filesystem.hpp>
#include <InferenceTest.hpp>

#if defined(ARMNN_SERIALIZER)
#include "armnnDeserializer/IDeserializer.hpp"
#endif
#if defined(ARMNN_CAFFE_PARSER)
#include "armnnCaffeParser/ICaffeParser.hpp"
#endif
#if defined(ARMNN_TF_PARSER)
#include "armnnTfParser/ITfParser.hpp"
#endif
#if defined(ARMNN_TF_LITE_PARSER)
#include "armnnTfLiteParser/ITfLiteParser.hpp"
#endif
#if defined(ARMNN_ONNX_PARSER)
#include "armnnOnnxParser/IOnnxParser.hpp"
#endif

#include <future>
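// MainImpl parses the model with the given TParser, loads it into an IRuntime via
// InferenceModel, fills the input tensors (from data files or generated data), runs
// inference for the requested number of iterations, and prints the outputs and timings.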
template<typename TParser, typename TDataType>
int MainImpl(const ExecuteNetworkParams& params,
             const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
    using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;

    std::vector<TContainer> inputDataContainers;

    try
    {
        // Creates an InferenceModel, which will parse the model and load it into an IRuntime.
        typename InferenceModel<TParser, TDataType>::Params inferenceModelParams;
        inferenceModelParams.m_ModelPath                      = params.m_ModelPath;
        inferenceModelParams.m_IsModelBinary                  = params.m_IsModelBinary;
        inferenceModelParams.m_ComputeDevices                 = params.m_ComputeDevices;
        inferenceModelParams.m_DynamicBackendsPath            = params.m_DynamicBackendsPath;
        inferenceModelParams.m_PrintIntermediateLayers        = params.m_PrintIntermediate;
        inferenceModelParams.m_VisualizePostOptimizationModel = params.m_EnableLayerDetails;
        inferenceModelParams.m_ParseUnsupported               = params.m_ParseUnsupported;
        inferenceModelParams.m_InferOutputShape               = params.m_InferOutputShape;
        inferenceModelParams.m_EnableFastMath                 = params.m_EnableFastMath;

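        // Copy the user-supplied input/output binding names and any explicit input shapes
        // into the model parameters.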
        for(const std::string& inputName: params.m_InputNames)
        {
            inferenceModelParams.m_InputBindings.push_back(inputName);
        }

        for(unsigned int i = 0; i < params.m_InputTensorShapes.size(); ++i)
        {
            inferenceModelParams.m_InputShapes.push_back(*params.m_InputTensorShapes[i]);
        }

        for(const std::string& outputName: params.m_OutputNames)
        {
            inferenceModelParams.m_OutputBindings.push_back(outputName);
        }

        inferenceModelParams.m_SubgraphId          = params.m_SubgraphId;
        inferenceModelParams.m_EnableFp16TurboMode = params.m_EnableFp16TurboMode;
        inferenceModelParams.m_EnableBf16TurboMode = params.m_EnableBf16TurboMode;

        InferenceModel<TParser, TDataType> model(inferenceModelParams,
                                                 params.m_EnableProfiling,
                                                 params.m_DynamicBackendsPath,
                                                 runtime);

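        // Populate one TContainer per input binding, either from the supplied data file
        // or with generated dummy data, honouring any user-provided tensor shape.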
        const size_t numInputs = inferenceModelParams.m_InputBindings.size();
        for(unsigned int i = 0; i < numInputs; ++i)
        {
            armnn::Optional<QuantizationParams> qParams = params.m_QuantizeInput ?
                                                          armnn::MakeOptional<QuantizationParams>(
                                                              model.GetInputQuantizationParams()) :
                                                          armnn::EmptyOptional();

            armnn::Optional<std::string> dataFile = params.m_GenerateTensorData ?
                                                    armnn::EmptyOptional() :
                                                    armnn::MakeOptional<std::string>(
                                                        params.m_InputTensorDataFilePaths[i]);

            unsigned int numElements = model.GetInputSize(i);
            if (params.m_InputTensorShapes.size() > i && params.m_InputTensorShapes[i])
            {
                // If the user has provided a tensor shape for the current input,
                // override numElements.
                numElements = params.m_InputTensorShapes[i]->GetNumElements();
            }

            TContainer tensorData;
            PopulateTensorWithData(tensorData,
                                   numElements,
                                   params.m_InputTypes[i],
                                   qParams,
                                   dataFile);

            inputDataContainers.push_back(tensorData);
        }

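        // Pre-allocate output storage: one container per output binding, using the element
        // type requested on the command line and the size reported by the model.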
        const size_t numOutputs = inferenceModelParams.m_OutputBindings.size();
        std::vector<TContainer> outputDataContainers;

        for (unsigned int i = 0; i < numOutputs; ++i)
        {
            if (params.m_OutputTypes[i].compare("float") == 0)
            {
                outputDataContainers.push_back(std::vector<float>(model.GetOutputSize(i)));
            }
            else if (params.m_OutputTypes[i].compare("int") == 0)
            {
                outputDataContainers.push_back(std::vector<int>(model.GetOutputSize(i)));
            }
            else if (params.m_OutputTypes[i].compare("qasymm8") == 0)
            {
                outputDataContainers.push_back(std::vector<uint8_t>(model.GetOutputSize(i)));
            }
            else
            {
                ARMNN_LOG(fatal) << "Unsupported tensor data type \"" << params.m_OutputTypes[i] << "\". ";
                return EXIT_FAILURE;
            }
        }

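        // Run inference for the requested number of iterations, printing the output tensors
        // and the inference time for each run, and checking against the optional threshold.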
        for (size_t x = 0; x < params.m_Iterations; x++)
        {
            // model.Run returns the inference time elapsed in EnqueueWorkload (in milliseconds).
            auto inference_duration = model.Run(inputDataContainers, outputDataContainers);

            if (params.m_GenerateTensorData)
            {
                ARMNN_LOG(warning) << "The input data was generated, note that the output will not be useful";
            }

            // Print output tensors
            const auto& infosOut = model.GetOutputBindingInfos();
            for (size_t i = 0; i < numOutputs; i++)
            {
                const armnn::TensorInfo& infoOut = infosOut[i].second;
                auto outputTensorFile = params.m_OutputTensorFiles.empty() ? "" : params.m_OutputTensorFiles[i];

                TensorPrinter printer(inferenceModelParams.m_OutputBindings[i],
                                      infoOut,
                                      outputTensorFile,
                                      params.m_DequantizeOutput);
                mapbox::util::apply_visitor(printer, outputDataContainers[i]);
            }

            ARMNN_LOG(info) << "\nInference time: " << std::setprecision(2)
                            << std::fixed << inference_duration.count() << " ms\n";

            // If thresholdTime == 0.0 (default), then it hasn't been supplied at command line.
            if (params.m_ThresholdTime != 0.0)
            {
                ARMNN_LOG(info) << "Threshold time: " << std::setprecision(2)
                                << std::fixed << params.m_ThresholdTime << " ms";
                auto thresholdMinusInference = params.m_ThresholdTime - inference_duration.count();
                ARMNN_LOG(info) << "Threshold time - Inference time: " << std::setprecision(2)
                                << std::fixed << thresholdMinusInference << " ms" << "\n";

                if (thresholdMinusInference < 0)
                {
                    std::string errorMessage = "Elapsed inference time is greater than provided threshold time.";
                    ARMNN_LOG(fatal) << errorMessage;
                }
            }
        }
    }
    catch (const armnn::Exception& e)
    {
        ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

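// Example invocation (illustrative only; exact flag names can vary between Arm NN
// releases, so check ExecuteNetwork --help for the options supported by this build):
//   ExecuteNetwork -m model.tflite -f tflite-binary -i input -o output -c CpuAcc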

// MAIN
int main(int argc, const char* argv[])
{
    // Configures logging for both the ARMNN library and this test program.
    #ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
    #else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
    #endif
    armnn::ConfigureLogging(true, true, level);

    // Get ExecuteNetwork parameters and runtime options from command line
    ProgramOptions ProgramOptions(argc, argv);

    // Create runtime
    std::shared_ptr<armnn::IRuntime> runtime(armnn::IRuntime::Create(ProgramOptions.m_RuntimeOptions));

    std::string modelFormat = ProgramOptions.m_ExNetParams.m_ModelFormat;

    // Forward to implementation based on the parser type
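    // The format string only needs to contain one of the substrings 'armnn', 'caffe',
    // 'onnx', 'tensorflow' or 'tflite'; in practice it usually also carries a '-binary'
    // or '-text' suffix describing the file encoding.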
    if (modelFormat.find("armnn") != std::string::npos)
    {
    #if defined(ARMNN_SERIALIZER)
        return MainImpl<armnnDeserializer::IDeserializer, float>(ProgramOptions.m_ExNetParams, runtime);
    #else
        ARMNN_LOG(fatal) << "Not built with serialization support.";
        return EXIT_FAILURE;
    #endif
    }
    else if (modelFormat.find("caffe") != std::string::npos)
    {
    #if defined(ARMNN_CAFFE_PARSER)
        return MainImpl<armnnCaffeParser::ICaffeParser, float>(ProgramOptions.m_ExNetParams, runtime);
    #else
        ARMNN_LOG(fatal) << "Not built with Caffe parser support.";
        return EXIT_FAILURE;
    #endif
    }
    else if (modelFormat.find("onnx") != std::string::npos)
    {
    #if defined(ARMNN_ONNX_PARSER)
        return MainImpl<armnnOnnxParser::IOnnxParser, float>(ProgramOptions.m_ExNetParams, runtime);
    #else
        ARMNN_LOG(fatal) << "Not built with Onnx parser support.";
        return EXIT_FAILURE;
    #endif
    }
    else if (modelFormat.find("tensorflow") != std::string::npos)
    {
    #if defined(ARMNN_TF_PARSER)
        return MainImpl<armnnTfParser::ITfParser, float>(ProgramOptions.m_ExNetParams, runtime);
    #else
        ARMNN_LOG(fatal) << "Not built with Tensorflow parser support.";
        return EXIT_FAILURE;
    #endif
    }
    else if (modelFormat.find("tflite") != std::string::npos)
    {
    #if defined(ARMNN_TF_LITE_PARSER)
        return MainImpl<armnnTfLiteParser::ITfLiteParser, float>(ProgramOptions.m_ExNetParams, runtime);
    #else
        ARMNN_LOG(fatal) << "Not built with Tensorflow-Lite parser support.";
        return EXIT_FAILURE;
    #endif
    }
    else
    {
        ARMNN_LOG(fatal) << "Unknown model format: '" << modelFormat
                         << "'. Please include 'armnn', 'caffe', 'tensorflow', 'tflite' or 'onnx'";
        return EXIT_FAILURE;
    }
}