blob: b6087c5e5a32baad8c6bd44f91e12ec213d7cfb3 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
David Monahan6bb47a72021-10-22 12:57:28 +01007#include <armnn/Utils.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Matthew Sloyan80c6b142020-09-08 12:00:32 +01009#include <armnn/utility/NumericCast.hpp>
Francis Murtagh40d27412021-10-28 11:11:35 +010010#include <armnnUtils/TContainer.hpp>
11
James Wardc89829f2020-10-12 14:17:36 +010012#include "CxxoptsUtils.hpp"
Matthew Sloyan80c6b142020-09-08 12:00:32 +010013
James Wardc89829f2020-10-12 14:17:36 +010014#include <cxxopts/cxxopts.hpp>
James Ward08f40162020-09-07 16:45:07 +010015#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000016
17#include <fstream>
18#include <iostream>
19#include <iomanip>
20#include <array>
21#include <chrono>
22
23using namespace std;
24using namespace std::chrono;
25using namespace armnn::test;
26
27namespace armnn
28{
29namespace test
30{
31
32template <typename TTestCaseDatabase, typename TModel>
33ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
34 int& numInferencesRef,
35 int& numCorrectInferencesRef,
36 const std::vector<unsigned int>& validationPredictions,
37 std::vector<unsigned int>* validationPredictionsOut,
38 TModel& model,
39 unsigned int testCaseId,
40 unsigned int label,
41 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000042 : InferenceModelTestCase<TModel>(
Francis Murtagh40d27412021-10-28 11:11:35 +010043 model, testCaseId, std::vector<armnnUtils::TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000044 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010045 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000046 , m_NumInferencesRef(numInferencesRef)
47 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
48 , m_ValidationPredictions(validationPredictions)
49 , m_ValidationPredictionsOut(validationPredictionsOut)
50{
51}
52
James Ward6d9f5c52020-09-28 11:56:35 +010053struct ClassifierResultProcessor
Derek Lambertiac737602019-05-16 16:33:00 +010054{
55 using ResultMap = std::map<float,int>;
56
57 ClassifierResultProcessor(float scale, int offset)
58 : m_Scale(scale)
59 , m_Offset(offset)
60 {}
61
62 void operator()(const std::vector<float>& values)
63 {
64 SortPredictions(values, [](float value)
65 {
66 return value;
67 });
68 }
69
Finn Williamsf806c4d2021-02-22 15:13:12 +000070 void operator()(const std::vector<int8_t>& values)
71 {
72 SortPredictions(values, [](int8_t value)
73 {
74 return value;
75 });
76 }
77
Derek Lambertiac737602019-05-16 16:33:00 +010078 void operator()(const std::vector<uint8_t>& values)
79 {
80 auto& scale = m_Scale;
81 auto& offset = m_Offset;
82 SortPredictions(values, [&scale, &offset](uint8_t value)
83 {
84 return armnn::Dequantize(value, scale, offset);
85 });
86 }
87
88 void operator()(const std::vector<int>& values)
89 {
Jan Eilers8eb25602020-03-09 12:13:48 +000090 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010091 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010092 }
93
94 ResultMap& GetResultMap() { return m_ResultMap; }
95
96private:
97 template<typename Container, typename Delegate>
98 void SortPredictions(const Container& c, Delegate delegate)
99 {
100 int index = 0;
101 for (const auto& value : c)
102 {
103 int classification = index++;
104 // Take the first class with each probability
105 // This avoids strange results when looping over batched results produced
106 // with identical test data.
107 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
108
109 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
110 {
111 // If the key is not already in the map, insert it.
112 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
113 }
114 }
115 }
116
117 ResultMap m_ResultMap;
118
119 float m_Scale=0.0f;
120 int m_Offset=0;
121};
122
telsoa014fcda012018-03-09 14:13:49 +0000123template <typename TTestCaseDatabase, typename TModel>
124TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
125{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000126 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000127 const auto testCaseId = this->GetTestCaseId();
128
Derek Lambertiac737602019-05-16 16:33:00 +0100129 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
James Ward6d9f5c52020-09-28 11:56:35 +0100130 mapbox::util::apply_visitor(resultProcessor, output);
Derek Lambertiac737602019-05-16 16:33:00 +0100131
Derek Lamberti08446972019-11-26 16:38:31 +0000132 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100133 auto it = resultProcessor.GetResultMap().rbegin();
134 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100135 {
Derek Lamberti08446972019-11-26 16:38:31 +0000136 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100137 " with value: " << (it->first);
138 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100139 }
140
Ferran Balaguerc602f292019-02-08 17:09:55 +0000141 unsigned int prediction = 0;
James Ward6d9f5c52020-09-28 11:56:35 +0100142 mapbox::util::apply_visitor([&](auto&& value)
Ferran Balaguerc602f292019-02-08 17:09:55 +0000143 {
Matthew Sloyan80c6b142020-09-08 12:00:32 +0100144 prediction = armnn::numeric_cast<unsigned int>(
Ferran Balaguerc602f292019-02-08 17:09:55 +0000145 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
146 },
147 output);
telsoa014fcda012018-03-09 14:13:49 +0000148
telsoa01c577f2c2018-08-31 09:22:23 +0100149 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000150 if (params.m_IterationCount == 0 && prediction != m_Label)
151 {
Derek Lamberti08446972019-11-26 16:38:31 +0000152 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000153 " is incorrect (should be " << m_Label << ")";
154 return TestCaseResult::Failed;
155 }
156
telsoa01c577f2c2018-08-31 09:22:23 +0100157 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000158 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
159 {
Derek Lamberti08446972019-11-26 16:38:31 +0000160 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000161 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
162 return TestCaseResult::Failed;
163 }
164
telsoa01c577f2c2018-08-31 09:22:23 +0100165 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000166 if (m_ValidationPredictionsOut)
167 {
168 m_ValidationPredictionsOut->push_back(prediction);
169 }
170
telsoa01c577f2c2018-08-31 09:22:23 +0100171 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000172 m_NumInferencesRef++;
173 if (prediction == m_Label)
174 {
175 m_NumCorrectInferencesRef++;
176 }
177
178 return TestCaseResult::Ok;
179}
180
181template <typename TDatabase, typename InferenceModel>
182template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
184 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
185 : m_ConstructModel(constructModel)
186 , m_ConstructDatabase(constructDatabase)
187 , m_NumInferences(0)
188 , m_NumCorrectInferences(0)
189{
190}
191
192template <typename TDatabase, typename InferenceModel>
193void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
James Wardc89829f2020-10-12 14:17:36 +0100194 cxxopts::Options& options, std::vector<std::string>& required)
telsoa014fcda012018-03-09 14:13:49 +0000195{
James Wardc89829f2020-10-12 14:17:36 +0100196 options
197 .allow_unrecognised_options()
198 .add_options()
199 ("validation-file-in",
200 "Reads expected predictions from the given file and confirms they match the actual predictions.",
201 cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
202 ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
203 cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
204 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
telsoa014fcda012018-03-09 14:13:49 +0000205
James Wardc89829f2020-10-12 14:17:36 +0100206 required.emplace_back("data-dir"); //add to required arguments to check
telsoa014fcda012018-03-09 14:13:49 +0000207
James Wardc89829f2020-10-12 14:17:36 +0100208 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
telsoa014fcda012018-03-09 14:13:49 +0000209}
210
211template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100212bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
213 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000214{
215 if (!ValidateDirectory(m_DataDir))
216 {
217 return false;
218 }
219
220 ReadPredictions();
221
Matthew Bentham3e68b972019-04-09 13:10:46 +0100222 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000223 if (!m_Model)
224 {
225 return false;
226 }
227
telsoa01c577f2c2018-08-31 09:22:23 +0100228 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000229 if (!m_Database)
230 {
231 return false;
232 }
233
234 return true;
235}
236
237template <typename TDatabase, typename InferenceModel>
238std::unique_ptr<IInferenceTestCase>
239ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
240{
241 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
242 if (testCaseData == nullptr)
243 {
244 return nullptr;
245 }
246
247 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
248 m_NumInferences,
249 m_NumCorrectInferences,
250 m_ValidationPredictions,
251 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
252 *m_Model,
253 testCaseId,
254 testCaseData->m_Label,
255 std::move(testCaseData->m_InputImage));
256}
257
258template <typename TDatabase, typename InferenceModel>
259bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
260{
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100261 const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
262 armnn::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000263 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000264
telsoa01c577f2c2018-08-31 09:22:23 +0100265 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000266 if (!m_ValidationFileOut.empty())
267 {
268 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
269 if (validationFileOut.good())
270 {
271 for (const unsigned int prediction : m_ValidationPredictionsOut)
272 {
273 validationFileOut << prediction << std::endl;
274 }
275 }
276 else
277 {
Derek Lamberti08446972019-11-26 16:38:31 +0000278 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000279 return false;
280 }
281 }
282
283 return true;
284}
285
286template <typename TDatabase, typename InferenceModel>
287void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
288{
telsoa01c577f2c2018-08-31 09:22:23 +0100289 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000290 if (!m_ValidationFileIn.empty())
291 {
292 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
293 if (validationFileIn.good())
294 {
295 while (!validationFileIn.eof())
296 {
297 unsigned int i;
298 validationFileIn >> i;
299 m_ValidationPredictions.emplace_back(i);
300 }
301 }
302 else
303 {
James Ward08f40162020-09-07 16:45:07 +0100304 throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
305 , m_ValidationFileIn));
telsoa014fcda012018-03-09 14:13:49 +0000306 }
307 }
308}
309
310template<typename TConstructTestCaseProvider>
311int InferenceTestMain(int argc,
312 char* argv[],
313 const std::vector<unsigned int>& defaultTestCaseIds,
314 TConstructTestCaseProvider constructTestCaseProvider)
315{
telsoa01c577f2c2018-08-31 09:22:23 +0100316 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000317#ifdef NDEBUG
318 armnn::LogSeverity level = armnn::LogSeverity::Info;
319#else
320 armnn::LogSeverity level = armnn::LogSeverity::Debug;
321#endif
322 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000323
324 try
325 {
326 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
327 if (!testCaseProvider)
328 {
329 return 1;
330 }
331
332 InferenceTestOptions inferenceTestOptions;
333 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
334 {
335 return 1;
336 }
337
338 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
339 return success ? 0 : 1;
340 }
341 catch (armnn::Exception const& e)
342 {
Derek Lamberti08446972019-11-26 16:38:31 +0000343 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000344 return 1;
345 }
346}
347
telsoa01c577f2c2018-08-31 09:22:23 +0100348//
349// This function allows us to create a classifier inference test based on:
350// - a model file name
351// - which can be a binary or a text file for protobuf formats
352// - an input tensor name
353// - an output tensor name
354// - a set of test case ids
355// - a callback method which creates an object that can return images
356// called 'Database' in these tests
357// - and an input tensor shape
358//
telsoa014fcda012018-03-09 14:13:49 +0000359template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100360 typename TParser,
361 typename TConstructDatabaseCallable>
362int ClassifierInferenceTestMain(int argc,
363 char* argv[],
364 const char* modelFilename,
365 bool isModelBinary,
366 const char* inputBindingName,
367 const char* outputBindingName,
368 const std::vector<unsigned int>& defaultTestCaseIds,
369 TConstructDatabaseCallable constructDatabase,
370 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000371
telsoa014fcda012018-03-09 14:13:49 +0000372{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100373 ARMNN_ASSERT(modelFilename);
374 ARMNN_ASSERT(inputBindingName);
375 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000376
telsoa014fcda012018-03-09 14:13:49 +0000377 return InferenceTestMain(argc, argv, defaultTestCaseIds,
378 [=]
379 ()
380 {
telsoa01c577f2c2018-08-31 09:22:23 +0100381 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000382 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
383
384 return make_unique<TestCaseProvider>(constructDatabase,
385 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100386 (const InferenceTestOptions &commonOptions,
387 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000388 {
389 if (!ValidateDirectory(modelOptions.m_ModelDir))
390 {
391 return std::unique_ptr<InferenceModel>();
392 }
393
394 typename InferenceModel::Params modelParams;
395 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000396 modelParams.m_InputBindings = { inputBindingName };
397 modelParams.m_OutputBindings = { outputBindingName };
398
399 if (inputTensorShape)
400 {
401 modelParams.m_InputShapes.push_back(*inputTensorShape);
402 }
403
telsoa014fcda012018-03-09 14:13:49 +0000404 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000405 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100406 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100407 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000408
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100409 return std::make_unique<InferenceModel>(modelParams,
410 commonOptions.m_EnableProfiling,
411 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000412 });
413 });
414}
415
416} // namespace test
417} // namespace armnn