blob: e8401f6bc3ff39368d1bca2980dae0d3c716bd59 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01007#include <armnn/utility/Assert.hpp>
Matthew Sloyan80c6b142020-09-08 12:00:32 +01008#include <armnn/utility/NumericCast.hpp>
9
telsoa014fcda012018-03-09 14:13:49 +000010#include <boost/program_options.hpp>
James Ward08f40162020-09-07 16:45:07 +010011#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000012
13#include <fstream>
14#include <iostream>
15#include <iomanip>
16#include <array>
17#include <chrono>
18
19using namespace std;
20using namespace std::chrono;
21using namespace armnn::test;
22
23namespace armnn
24{
25namespace test
26{
27
James Ward6d9f5c52020-09-28 11:56:35 +010028using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010029
telsoa014fcda012018-03-09 14:13:49 +000030template <typename TTestCaseDatabase, typename TModel>
31ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
32 int& numInferencesRef,
33 int& numCorrectInferencesRef,
34 const std::vector<unsigned int>& validationPredictions,
35 std::vector<unsigned int>* validationPredictionsOut,
36 TModel& model,
37 unsigned int testCaseId,
38 unsigned int label,
39 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000040 : InferenceModelTestCase<TModel>(
41 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000042 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010043 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000044 , m_NumInferencesRef(numInferencesRef)
45 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
46 , m_ValidationPredictions(validationPredictions)
47 , m_ValidationPredictionsOut(validationPredictionsOut)
48{
49}
50
James Ward6d9f5c52020-09-28 11:56:35 +010051struct ClassifierResultProcessor
Derek Lambertiac737602019-05-16 16:33:00 +010052{
53 using ResultMap = std::map<float,int>;
54
55 ClassifierResultProcessor(float scale, int offset)
56 : m_Scale(scale)
57 , m_Offset(offset)
58 {}
59
60 void operator()(const std::vector<float>& values)
61 {
62 SortPredictions(values, [](float value)
63 {
64 return value;
65 });
66 }
67
68 void operator()(const std::vector<uint8_t>& values)
69 {
70 auto& scale = m_Scale;
71 auto& offset = m_Offset;
72 SortPredictions(values, [&scale, &offset](uint8_t value)
73 {
74 return armnn::Dequantize(value, scale, offset);
75 });
76 }
77
78 void operator()(const std::vector<int>& values)
79 {
Jan Eilers8eb25602020-03-09 12:13:48 +000080 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010081 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010082 }
83
84 ResultMap& GetResultMap() { return m_ResultMap; }
85
86private:
87 template<typename Container, typename Delegate>
88 void SortPredictions(const Container& c, Delegate delegate)
89 {
90 int index = 0;
91 for (const auto& value : c)
92 {
93 int classification = index++;
94 // Take the first class with each probability
95 // This avoids strange results when looping over batched results produced
96 // with identical test data.
97 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
98
99 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
100 {
101 // If the key is not already in the map, insert it.
102 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
103 }
104 }
105 }
106
107 ResultMap m_ResultMap;
108
109 float m_Scale=0.0f;
110 int m_Offset=0;
111};
112
telsoa014fcda012018-03-09 14:13:49 +0000113template <typename TTestCaseDatabase, typename TModel>
114TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
115{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000117 const auto testCaseId = this->GetTestCaseId();
118
Derek Lambertiac737602019-05-16 16:33:00 +0100119 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
James Ward6d9f5c52020-09-28 11:56:35 +0100120 mapbox::util::apply_visitor(resultProcessor, output);
Derek Lambertiac737602019-05-16 16:33:00 +0100121
Derek Lamberti08446972019-11-26 16:38:31 +0000122 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100123 auto it = resultProcessor.GetResultMap().rbegin();
124 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100125 {
Derek Lamberti08446972019-11-26 16:38:31 +0000126 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100127 " with value: " << (it->first);
128 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100129 }
130
Ferran Balaguerc602f292019-02-08 17:09:55 +0000131 unsigned int prediction = 0;
James Ward6d9f5c52020-09-28 11:56:35 +0100132 mapbox::util::apply_visitor([&](auto&& value)
Ferran Balaguerc602f292019-02-08 17:09:55 +0000133 {
Matthew Sloyan80c6b142020-09-08 12:00:32 +0100134 prediction = armnn::numeric_cast<unsigned int>(
Ferran Balaguerc602f292019-02-08 17:09:55 +0000135 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
136 },
137 output);
telsoa014fcda012018-03-09 14:13:49 +0000138
telsoa01c577f2c2018-08-31 09:22:23 +0100139 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000140 if (params.m_IterationCount == 0 && prediction != m_Label)
141 {
Derek Lamberti08446972019-11-26 16:38:31 +0000142 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000143 " is incorrect (should be " << m_Label << ")";
144 return TestCaseResult::Failed;
145 }
146
telsoa01c577f2c2018-08-31 09:22:23 +0100147 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000148 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
149 {
Derek Lamberti08446972019-11-26 16:38:31 +0000150 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000151 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
152 return TestCaseResult::Failed;
153 }
154
telsoa01c577f2c2018-08-31 09:22:23 +0100155 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000156 if (m_ValidationPredictionsOut)
157 {
158 m_ValidationPredictionsOut->push_back(prediction);
159 }
160
telsoa01c577f2c2018-08-31 09:22:23 +0100161 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000162 m_NumInferencesRef++;
163 if (prediction == m_Label)
164 {
165 m_NumCorrectInferencesRef++;
166 }
167
168 return TestCaseResult::Ok;
169}
170
171template <typename TDatabase, typename InferenceModel>
172template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
173ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
174 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
175 : m_ConstructModel(constructModel)
176 , m_ConstructDatabase(constructDatabase)
177 , m_NumInferences(0)
178 , m_NumCorrectInferences(0)
179{
180}
181
182template <typename TDatabase, typename InferenceModel>
183void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
184 boost::program_options::options_description& options)
185{
186 namespace po = boost::program_options;
187
188 options.add_options()
189 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
190 "Reads expected predictions from the given file and confirms they match the actual predictions.")
191 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
192 "Predictions are saved to the given file for later use via --validation-file-in.")
193 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
194 "Path to directory containing test data");
195
196 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
197}
198
199template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100200bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
201 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000202{
203 if (!ValidateDirectory(m_DataDir))
204 {
205 return false;
206 }
207
208 ReadPredictions();
209
Matthew Bentham3e68b972019-04-09 13:10:46 +0100210 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000211 if (!m_Model)
212 {
213 return false;
214 }
215
telsoa01c577f2c2018-08-31 09:22:23 +0100216 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000217 if (!m_Database)
218 {
219 return false;
220 }
221
222 return true;
223}
224
225template <typename TDatabase, typename InferenceModel>
226std::unique_ptr<IInferenceTestCase>
227ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
228{
229 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
230 if (testCaseData == nullptr)
231 {
232 return nullptr;
233 }
234
235 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
236 m_NumInferences,
237 m_NumCorrectInferences,
238 m_ValidationPredictions,
239 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
240 *m_Model,
241 testCaseId,
242 testCaseData->m_Label,
243 std::move(testCaseData->m_InputImage));
244}
245
246template <typename TDatabase, typename InferenceModel>
247bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
248{
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100249 const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
250 armnn::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000251 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000252
telsoa01c577f2c2018-08-31 09:22:23 +0100253 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000254 if (!m_ValidationFileOut.empty())
255 {
256 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
257 if (validationFileOut.good())
258 {
259 for (const unsigned int prediction : m_ValidationPredictionsOut)
260 {
261 validationFileOut << prediction << std::endl;
262 }
263 }
264 else
265 {
Derek Lamberti08446972019-11-26 16:38:31 +0000266 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000267 return false;
268 }
269 }
270
271 return true;
272}
273
274template <typename TDatabase, typename InferenceModel>
275void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
276{
telsoa01c577f2c2018-08-31 09:22:23 +0100277 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000278 if (!m_ValidationFileIn.empty())
279 {
280 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
281 if (validationFileIn.good())
282 {
283 while (!validationFileIn.eof())
284 {
285 unsigned int i;
286 validationFileIn >> i;
287 m_ValidationPredictions.emplace_back(i);
288 }
289 }
290 else
291 {
James Ward08f40162020-09-07 16:45:07 +0100292 throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
293 , m_ValidationFileIn));
telsoa014fcda012018-03-09 14:13:49 +0000294 }
295 }
296}
297
298template<typename TConstructTestCaseProvider>
299int InferenceTestMain(int argc,
300 char* argv[],
301 const std::vector<unsigned int>& defaultTestCaseIds,
302 TConstructTestCaseProvider constructTestCaseProvider)
303{
telsoa01c577f2c2018-08-31 09:22:23 +0100304 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000305#ifdef NDEBUG
306 armnn::LogSeverity level = armnn::LogSeverity::Info;
307#else
308 armnn::LogSeverity level = armnn::LogSeverity::Debug;
309#endif
310 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000311
312 try
313 {
314 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
315 if (!testCaseProvider)
316 {
317 return 1;
318 }
319
320 InferenceTestOptions inferenceTestOptions;
321 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
322 {
323 return 1;
324 }
325
326 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
327 return success ? 0 : 1;
328 }
329 catch (armnn::Exception const& e)
330 {
Derek Lamberti08446972019-11-26 16:38:31 +0000331 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000332 return 1;
333 }
334}
335
telsoa01c577f2c2018-08-31 09:22:23 +0100336//
337// This function allows us to create a classifier inference test based on:
338// - a model file name
339// - which can be a binary or a text file for protobuf formats
340// - an input tensor name
341// - an output tensor name
342// - a set of test case ids
343// - a callback method which creates an object that can return images
344// called 'Database' in these tests
345// - and an input tensor shape
346//
telsoa014fcda012018-03-09 14:13:49 +0000347template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100348 typename TParser,
349 typename TConstructDatabaseCallable>
350int ClassifierInferenceTestMain(int argc,
351 char* argv[],
352 const char* modelFilename,
353 bool isModelBinary,
354 const char* inputBindingName,
355 const char* outputBindingName,
356 const std::vector<unsigned int>& defaultTestCaseIds,
357 TConstructDatabaseCallable constructDatabase,
358 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000359
telsoa014fcda012018-03-09 14:13:49 +0000360{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100361 ARMNN_ASSERT(modelFilename);
362 ARMNN_ASSERT(inputBindingName);
363 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000364
telsoa014fcda012018-03-09 14:13:49 +0000365 return InferenceTestMain(argc, argv, defaultTestCaseIds,
366 [=]
367 ()
368 {
telsoa01c577f2c2018-08-31 09:22:23 +0100369 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000370 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
371
372 return make_unique<TestCaseProvider>(constructDatabase,
373 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100374 (const InferenceTestOptions &commonOptions,
375 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000376 {
377 if (!ValidateDirectory(modelOptions.m_ModelDir))
378 {
379 return std::unique_ptr<InferenceModel>();
380 }
381
382 typename InferenceModel::Params modelParams;
383 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000384 modelParams.m_InputBindings = { inputBindingName };
385 modelParams.m_OutputBindings = { outputBindingName };
386
387 if (inputTensorShape)
388 {
389 modelParams.m_InputShapes.push_back(*inputTensorShape);
390 }
391
telsoa014fcda012018-03-09 14:13:49 +0000392 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000393 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100394 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100395 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000396
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100397 return std::make_unique<InferenceModel>(modelParams,
398 commonOptions.m_EnableProfiling,
399 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000400 });
401 });
402}
403
404} // namespace test
405} // namespace armnn