blob: 91a90f3820942c6e814bc5957e24e98ff4bcc568 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
David Monahan6bb47a72021-10-22 12:57:28 +01007#include <armnn/Utils.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Matthew Sloyan80c6b142020-09-08 12:00:32 +01009#include <armnn/utility/NumericCast.hpp>
James Wardc89829f2020-10-12 14:17:36 +010010#include "CxxoptsUtils.hpp"
Matthew Sloyan80c6b142020-09-08 12:00:32 +010011
James Wardc89829f2020-10-12 14:17:36 +010012#include <cxxopts/cxxopts.hpp>
James Ward08f40162020-09-07 16:45:07 +010013#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000014
15#include <fstream>
16#include <iostream>
17#include <iomanip>
18#include <array>
19#include <chrono>
20
21using namespace std;
22using namespace std::chrono;
23using namespace armnn::test;
24
25namespace armnn
26{
27namespace test
28{
29
30template <typename TTestCaseDatabase, typename TModel>
31ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
32 int& numInferencesRef,
33 int& numCorrectInferencesRef,
34 const std::vector<unsigned int>& validationPredictions,
35 std::vector<unsigned int>* validationPredictionsOut,
36 TModel& model,
37 unsigned int testCaseId,
38 unsigned int label,
39 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000040 : InferenceModelTestCase<TModel>(
David Monahan6bb47a72021-10-22 12:57:28 +010041 model, testCaseId, std::vector<armnn::TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000042 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010043 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000044 , m_NumInferencesRef(numInferencesRef)
45 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
46 , m_ValidationPredictions(validationPredictions)
47 , m_ValidationPredictionsOut(validationPredictionsOut)
48{
49}
50
James Ward6d9f5c52020-09-28 11:56:35 +010051struct ClassifierResultProcessor
Derek Lambertiac737602019-05-16 16:33:00 +010052{
53 using ResultMap = std::map<float,int>;
54
55 ClassifierResultProcessor(float scale, int offset)
56 : m_Scale(scale)
57 , m_Offset(offset)
58 {}
59
60 void operator()(const std::vector<float>& values)
61 {
62 SortPredictions(values, [](float value)
63 {
64 return value;
65 });
66 }
67
Finn Williamsf806c4d2021-02-22 15:13:12 +000068 void operator()(const std::vector<int8_t>& values)
69 {
70 SortPredictions(values, [](int8_t value)
71 {
72 return value;
73 });
74 }
75
Derek Lambertiac737602019-05-16 16:33:00 +010076 void operator()(const std::vector<uint8_t>& values)
77 {
78 auto& scale = m_Scale;
79 auto& offset = m_Offset;
80 SortPredictions(values, [&scale, &offset](uint8_t value)
81 {
82 return armnn::Dequantize(value, scale, offset);
83 });
84 }
85
86 void operator()(const std::vector<int>& values)
87 {
Jan Eilers8eb25602020-03-09 12:13:48 +000088 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010089 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010090 }
91
92 ResultMap& GetResultMap() { return m_ResultMap; }
93
94private:
95 template<typename Container, typename Delegate>
96 void SortPredictions(const Container& c, Delegate delegate)
97 {
98 int index = 0;
99 for (const auto& value : c)
100 {
101 int classification = index++;
102 // Take the first class with each probability
103 // This avoids strange results when looping over batched results produced
104 // with identical test data.
105 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
106
107 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
108 {
109 // If the key is not already in the map, insert it.
110 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
111 }
112 }
113 }
114
115 ResultMap m_ResultMap;
116
117 float m_Scale=0.0f;
118 int m_Offset=0;
119};
120
telsoa014fcda012018-03-09 14:13:49 +0000121template <typename TTestCaseDatabase, typename TModel>
122TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
123{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000124 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000125 const auto testCaseId = this->GetTestCaseId();
126
Derek Lambertiac737602019-05-16 16:33:00 +0100127 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
James Ward6d9f5c52020-09-28 11:56:35 +0100128 mapbox::util::apply_visitor(resultProcessor, output);
Derek Lambertiac737602019-05-16 16:33:00 +0100129
Derek Lamberti08446972019-11-26 16:38:31 +0000130 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100131 auto it = resultProcessor.GetResultMap().rbegin();
132 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100133 {
Derek Lamberti08446972019-11-26 16:38:31 +0000134 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100135 " with value: " << (it->first);
136 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100137 }
138
Ferran Balaguerc602f292019-02-08 17:09:55 +0000139 unsigned int prediction = 0;
James Ward6d9f5c52020-09-28 11:56:35 +0100140 mapbox::util::apply_visitor([&](auto&& value)
Ferran Balaguerc602f292019-02-08 17:09:55 +0000141 {
Matthew Sloyan80c6b142020-09-08 12:00:32 +0100142 prediction = armnn::numeric_cast<unsigned int>(
Ferran Balaguerc602f292019-02-08 17:09:55 +0000143 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
144 },
145 output);
telsoa014fcda012018-03-09 14:13:49 +0000146
telsoa01c577f2c2018-08-31 09:22:23 +0100147 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000148 if (params.m_IterationCount == 0 && prediction != m_Label)
149 {
Derek Lamberti08446972019-11-26 16:38:31 +0000150 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000151 " is incorrect (should be " << m_Label << ")";
152 return TestCaseResult::Failed;
153 }
154
telsoa01c577f2c2018-08-31 09:22:23 +0100155 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000156 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
157 {
Derek Lamberti08446972019-11-26 16:38:31 +0000158 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000159 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
160 return TestCaseResult::Failed;
161 }
162
telsoa01c577f2c2018-08-31 09:22:23 +0100163 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000164 if (m_ValidationPredictionsOut)
165 {
166 m_ValidationPredictionsOut->push_back(prediction);
167 }
168
telsoa01c577f2c2018-08-31 09:22:23 +0100169 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000170 m_NumInferencesRef++;
171 if (prediction == m_Label)
172 {
173 m_NumCorrectInferencesRef++;
174 }
175
176 return TestCaseResult::Ok;
177}
178
179template <typename TDatabase, typename InferenceModel>
180template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
181ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
182 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
183 : m_ConstructModel(constructModel)
184 , m_ConstructDatabase(constructDatabase)
185 , m_NumInferences(0)
186 , m_NumCorrectInferences(0)
187{
188}
189
190template <typename TDatabase, typename InferenceModel>
191void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
James Wardc89829f2020-10-12 14:17:36 +0100192 cxxopts::Options& options, std::vector<std::string>& required)
telsoa014fcda012018-03-09 14:13:49 +0000193{
James Wardc89829f2020-10-12 14:17:36 +0100194 options
195 .allow_unrecognised_options()
196 .add_options()
197 ("validation-file-in",
198 "Reads expected predictions from the given file and confirms they match the actual predictions.",
199 cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
200 ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
201 cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
202 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
telsoa014fcda012018-03-09 14:13:49 +0000203
James Wardc89829f2020-10-12 14:17:36 +0100204 required.emplace_back("data-dir"); //add to required arguments to check
telsoa014fcda012018-03-09 14:13:49 +0000205
James Wardc89829f2020-10-12 14:17:36 +0100206 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
telsoa014fcda012018-03-09 14:13:49 +0000207}
208
209template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100210bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
211 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000212{
213 if (!ValidateDirectory(m_DataDir))
214 {
215 return false;
216 }
217
218 ReadPredictions();
219
Matthew Bentham3e68b972019-04-09 13:10:46 +0100220 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000221 if (!m_Model)
222 {
223 return false;
224 }
225
telsoa01c577f2c2018-08-31 09:22:23 +0100226 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000227 if (!m_Database)
228 {
229 return false;
230 }
231
232 return true;
233}
234
235template <typename TDatabase, typename InferenceModel>
236std::unique_ptr<IInferenceTestCase>
237ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
238{
239 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
240 if (testCaseData == nullptr)
241 {
242 return nullptr;
243 }
244
245 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
246 m_NumInferences,
247 m_NumCorrectInferences,
248 m_ValidationPredictions,
249 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
250 *m_Model,
251 testCaseId,
252 testCaseData->m_Label,
253 std::move(testCaseData->m_InputImage));
254}
255
256template <typename TDatabase, typename InferenceModel>
257bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
258{
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100259 const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
260 armnn::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000261 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000262
telsoa01c577f2c2018-08-31 09:22:23 +0100263 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000264 if (!m_ValidationFileOut.empty())
265 {
266 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
267 if (validationFileOut.good())
268 {
269 for (const unsigned int prediction : m_ValidationPredictionsOut)
270 {
271 validationFileOut << prediction << std::endl;
272 }
273 }
274 else
275 {
Derek Lamberti08446972019-11-26 16:38:31 +0000276 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000277 return false;
278 }
279 }
280
281 return true;
282}
283
284template <typename TDatabase, typename InferenceModel>
285void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
286{
telsoa01c577f2c2018-08-31 09:22:23 +0100287 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000288 if (!m_ValidationFileIn.empty())
289 {
290 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
291 if (validationFileIn.good())
292 {
293 while (!validationFileIn.eof())
294 {
295 unsigned int i;
296 validationFileIn >> i;
297 m_ValidationPredictions.emplace_back(i);
298 }
299 }
300 else
301 {
James Ward08f40162020-09-07 16:45:07 +0100302 throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
303 , m_ValidationFileIn));
telsoa014fcda012018-03-09 14:13:49 +0000304 }
305 }
306}
307
308template<typename TConstructTestCaseProvider>
309int InferenceTestMain(int argc,
310 char* argv[],
311 const std::vector<unsigned int>& defaultTestCaseIds,
312 TConstructTestCaseProvider constructTestCaseProvider)
313{
telsoa01c577f2c2018-08-31 09:22:23 +0100314 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000315#ifdef NDEBUG
316 armnn::LogSeverity level = armnn::LogSeverity::Info;
317#else
318 armnn::LogSeverity level = armnn::LogSeverity::Debug;
319#endif
320 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000321
322 try
323 {
324 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
325 if (!testCaseProvider)
326 {
327 return 1;
328 }
329
330 InferenceTestOptions inferenceTestOptions;
331 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
332 {
333 return 1;
334 }
335
336 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
337 return success ? 0 : 1;
338 }
339 catch (armnn::Exception const& e)
340 {
Derek Lamberti08446972019-11-26 16:38:31 +0000341 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000342 return 1;
343 }
344}
345
telsoa01c577f2c2018-08-31 09:22:23 +0100346//
347// This function allows us to create a classifier inference test based on:
348// - a model file name
349// - which can be a binary or a text file for protobuf formats
350// - an input tensor name
351// - an output tensor name
352// - a set of test case ids
353// - a callback method which creates an object that can return images
354// called 'Database' in these tests
355// - and an input tensor shape
356//
telsoa014fcda012018-03-09 14:13:49 +0000357template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100358 typename TParser,
359 typename TConstructDatabaseCallable>
360int ClassifierInferenceTestMain(int argc,
361 char* argv[],
362 const char* modelFilename,
363 bool isModelBinary,
364 const char* inputBindingName,
365 const char* outputBindingName,
366 const std::vector<unsigned int>& defaultTestCaseIds,
367 TConstructDatabaseCallable constructDatabase,
368 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000369
telsoa014fcda012018-03-09 14:13:49 +0000370{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100371 ARMNN_ASSERT(modelFilename);
372 ARMNN_ASSERT(inputBindingName);
373 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000374
telsoa014fcda012018-03-09 14:13:49 +0000375 return InferenceTestMain(argc, argv, defaultTestCaseIds,
376 [=]
377 ()
378 {
telsoa01c577f2c2018-08-31 09:22:23 +0100379 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000380 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
381
382 return make_unique<TestCaseProvider>(constructDatabase,
383 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100384 (const InferenceTestOptions &commonOptions,
385 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000386 {
387 if (!ValidateDirectory(modelOptions.m_ModelDir))
388 {
389 return std::unique_ptr<InferenceModel>();
390 }
391
392 typename InferenceModel::Params modelParams;
393 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000394 modelParams.m_InputBindings = { inputBindingName };
395 modelParams.m_OutputBindings = { outputBindingName };
396
397 if (inputTensorShape)
398 {
399 modelParams.m_InputShapes.push_back(*inputTensorShape);
400 }
401
telsoa014fcda012018-03-09 14:13:49 +0000402 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000403 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100404 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100405 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000406
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100407 return std::make_unique<InferenceModel>(modelParams,
408 commonOptions.m_EnableProfiling,
409 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000410 });
411 });
412}
413
414} // namespace test
415} // namespace armnn