blob: 3d6dae335aa9606732fb398e3436ff37f703c8b7 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01007#include <armnn/utility/Assert.hpp>
Matthew Sloyan80c6b142020-09-08 12:00:32 +01008#include <armnn/utility/NumericCast.hpp>
James Wardc89829f2020-10-12 14:17:36 +01009#include "CxxoptsUtils.hpp"
Matthew Sloyan80c6b142020-09-08 12:00:32 +010010
James Wardc89829f2020-10-12 14:17:36 +010011#include <cxxopts/cxxopts.hpp>
James Ward08f40162020-09-07 16:45:07 +010012#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000013
14#include <fstream>
15#include <iostream>
16#include <iomanip>
17#include <array>
18#include <chrono>
19
20using namespace std;
21using namespace std::chrono;
22using namespace armnn::test;
23
24namespace armnn
25{
26namespace test
27{
28
James Ward6d9f5c52020-09-28 11:56:35 +010029using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010030
telsoa014fcda012018-03-09 14:13:49 +000031template <typename TTestCaseDatabase, typename TModel>
32ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
33 int& numInferencesRef,
34 int& numCorrectInferencesRef,
35 const std::vector<unsigned int>& validationPredictions,
36 std::vector<unsigned int>* validationPredictionsOut,
37 TModel& model,
38 unsigned int testCaseId,
39 unsigned int label,
40 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000041 : InferenceModelTestCase<TModel>(
42 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000043 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010044 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000045 , m_NumInferencesRef(numInferencesRef)
46 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47 , m_ValidationPredictions(validationPredictions)
48 , m_ValidationPredictionsOut(validationPredictionsOut)
49{
50}
51
James Ward6d9f5c52020-09-28 11:56:35 +010052struct ClassifierResultProcessor
Derek Lambertiac737602019-05-16 16:33:00 +010053{
54 using ResultMap = std::map<float,int>;
55
56 ClassifierResultProcessor(float scale, int offset)
57 : m_Scale(scale)
58 , m_Offset(offset)
59 {}
60
61 void operator()(const std::vector<float>& values)
62 {
63 SortPredictions(values, [](float value)
64 {
65 return value;
66 });
67 }
68
69 void operator()(const std::vector<uint8_t>& values)
70 {
71 auto& scale = m_Scale;
72 auto& offset = m_Offset;
73 SortPredictions(values, [&scale, &offset](uint8_t value)
74 {
75 return armnn::Dequantize(value, scale, offset);
76 });
77 }
78
79 void operator()(const std::vector<int>& values)
80 {
Jan Eilers8eb25602020-03-09 12:13:48 +000081 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010082 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010083 }
84
85 ResultMap& GetResultMap() { return m_ResultMap; }
86
87private:
88 template<typename Container, typename Delegate>
89 void SortPredictions(const Container& c, Delegate delegate)
90 {
91 int index = 0;
92 for (const auto& value : c)
93 {
94 int classification = index++;
95 // Take the first class with each probability
96 // This avoids strange results when looping over batched results produced
97 // with identical test data.
98 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
99
100 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
101 {
102 // If the key is not already in the map, insert it.
103 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
104 }
105 }
106 }
107
108 ResultMap m_ResultMap;
109
110 float m_Scale=0.0f;
111 int m_Offset=0;
112};
113
telsoa014fcda012018-03-09 14:13:49 +0000114template <typename TTestCaseDatabase, typename TModel>
115TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
116{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000117 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000118 const auto testCaseId = this->GetTestCaseId();
119
Derek Lambertiac737602019-05-16 16:33:00 +0100120 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
James Ward6d9f5c52020-09-28 11:56:35 +0100121 mapbox::util::apply_visitor(resultProcessor, output);
Derek Lambertiac737602019-05-16 16:33:00 +0100122
Derek Lamberti08446972019-11-26 16:38:31 +0000123 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100124 auto it = resultProcessor.GetResultMap().rbegin();
125 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100126 {
Derek Lamberti08446972019-11-26 16:38:31 +0000127 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100128 " with value: " << (it->first);
129 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100130 }
131
Ferran Balaguerc602f292019-02-08 17:09:55 +0000132 unsigned int prediction = 0;
James Ward6d9f5c52020-09-28 11:56:35 +0100133 mapbox::util::apply_visitor([&](auto&& value)
Ferran Balaguerc602f292019-02-08 17:09:55 +0000134 {
Matthew Sloyan80c6b142020-09-08 12:00:32 +0100135 prediction = armnn::numeric_cast<unsigned int>(
Ferran Balaguerc602f292019-02-08 17:09:55 +0000136 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
137 },
138 output);
telsoa014fcda012018-03-09 14:13:49 +0000139
telsoa01c577f2c2018-08-31 09:22:23 +0100140 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000141 if (params.m_IterationCount == 0 && prediction != m_Label)
142 {
Derek Lamberti08446972019-11-26 16:38:31 +0000143 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000144 " is incorrect (should be " << m_Label << ")";
145 return TestCaseResult::Failed;
146 }
147
telsoa01c577f2c2018-08-31 09:22:23 +0100148 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000149 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
150 {
Derek Lamberti08446972019-11-26 16:38:31 +0000151 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000152 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
153 return TestCaseResult::Failed;
154 }
155
telsoa01c577f2c2018-08-31 09:22:23 +0100156 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000157 if (m_ValidationPredictionsOut)
158 {
159 m_ValidationPredictionsOut->push_back(prediction);
160 }
161
telsoa01c577f2c2018-08-31 09:22:23 +0100162 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000163 m_NumInferencesRef++;
164 if (prediction == m_Label)
165 {
166 m_NumCorrectInferencesRef++;
167 }
168
169 return TestCaseResult::Ok;
170}
171
172template <typename TDatabase, typename InferenceModel>
173template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
175 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
176 : m_ConstructModel(constructModel)
177 , m_ConstructDatabase(constructDatabase)
178 , m_NumInferences(0)
179 , m_NumCorrectInferences(0)
180{
181}
182
183template <typename TDatabase, typename InferenceModel>
184void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
James Wardc89829f2020-10-12 14:17:36 +0100185 cxxopts::Options& options, std::vector<std::string>& required)
telsoa014fcda012018-03-09 14:13:49 +0000186{
James Wardc89829f2020-10-12 14:17:36 +0100187 options
188 .allow_unrecognised_options()
189 .add_options()
190 ("validation-file-in",
191 "Reads expected predictions from the given file and confirms they match the actual predictions.",
192 cxxopts::value<std::string>(m_ValidationFileIn)->default_value(""))
193 ("validation-file-out", "Predictions are saved to the given file for later use via --validation-file-in.",
194 cxxopts::value<std::string>(m_ValidationFileOut)->default_value(""))
195 ("d,data-dir", "Path to directory containing test data", cxxopts::value<std::string>(m_DataDir));
telsoa014fcda012018-03-09 14:13:49 +0000196
James Wardc89829f2020-10-12 14:17:36 +0100197 required.emplace_back("data-dir"); //add to required arguments to check
telsoa014fcda012018-03-09 14:13:49 +0000198
James Wardc89829f2020-10-12 14:17:36 +0100199 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
telsoa014fcda012018-03-09 14:13:49 +0000200}
201
202template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100203bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
204 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000205{
206 if (!ValidateDirectory(m_DataDir))
207 {
208 return false;
209 }
210
211 ReadPredictions();
212
Matthew Bentham3e68b972019-04-09 13:10:46 +0100213 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000214 if (!m_Model)
215 {
216 return false;
217 }
218
telsoa01c577f2c2018-08-31 09:22:23 +0100219 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000220 if (!m_Database)
221 {
222 return false;
223 }
224
225 return true;
226}
227
228template <typename TDatabase, typename InferenceModel>
229std::unique_ptr<IInferenceTestCase>
230ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
231{
232 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
233 if (testCaseData == nullptr)
234 {
235 return nullptr;
236 }
237
238 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
239 m_NumInferences,
240 m_NumCorrectInferences,
241 m_ValidationPredictions,
242 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
243 *m_Model,
244 testCaseId,
245 testCaseData->m_Label,
246 std::move(testCaseData->m_InputImage));
247}
248
249template <typename TDatabase, typename InferenceModel>
250bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
251{
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100252 const double accuracy = armnn::numeric_cast<double>(m_NumCorrectInferences) /
253 armnn::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000254 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000255
telsoa01c577f2c2018-08-31 09:22:23 +0100256 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000257 if (!m_ValidationFileOut.empty())
258 {
259 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
260 if (validationFileOut.good())
261 {
262 for (const unsigned int prediction : m_ValidationPredictionsOut)
263 {
264 validationFileOut << prediction << std::endl;
265 }
266 }
267 else
268 {
Derek Lamberti08446972019-11-26 16:38:31 +0000269 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000270 return false;
271 }
272 }
273
274 return true;
275}
276
277template <typename TDatabase, typename InferenceModel>
278void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
279{
telsoa01c577f2c2018-08-31 09:22:23 +0100280 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000281 if (!m_ValidationFileIn.empty())
282 {
283 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
284 if (validationFileIn.good())
285 {
286 while (!validationFileIn.eof())
287 {
288 unsigned int i;
289 validationFileIn >> i;
290 m_ValidationPredictions.emplace_back(i);
291 }
292 }
293 else
294 {
James Ward08f40162020-09-07 16:45:07 +0100295 throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
296 , m_ValidationFileIn));
telsoa014fcda012018-03-09 14:13:49 +0000297 }
298 }
299}
300
301template<typename TConstructTestCaseProvider>
302int InferenceTestMain(int argc,
303 char* argv[],
304 const std::vector<unsigned int>& defaultTestCaseIds,
305 TConstructTestCaseProvider constructTestCaseProvider)
306{
telsoa01c577f2c2018-08-31 09:22:23 +0100307 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000308#ifdef NDEBUG
309 armnn::LogSeverity level = armnn::LogSeverity::Info;
310#else
311 armnn::LogSeverity level = armnn::LogSeverity::Debug;
312#endif
313 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000314
315 try
316 {
317 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
318 if (!testCaseProvider)
319 {
320 return 1;
321 }
322
323 InferenceTestOptions inferenceTestOptions;
324 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
325 {
326 return 1;
327 }
328
329 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
330 return success ? 0 : 1;
331 }
332 catch (armnn::Exception const& e)
333 {
Derek Lamberti08446972019-11-26 16:38:31 +0000334 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000335 return 1;
336 }
337}
338
telsoa01c577f2c2018-08-31 09:22:23 +0100339//
340// This function allows us to create a classifier inference test based on:
341// - a model file name
342// - which can be a binary or a text file for protobuf formats
343// - an input tensor name
344// - an output tensor name
345// - a set of test case ids
346// - a callback method which creates an object that can return images
347// called 'Database' in these tests
348// - and an input tensor shape
349//
telsoa014fcda012018-03-09 14:13:49 +0000350template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100351 typename TParser,
352 typename TConstructDatabaseCallable>
353int ClassifierInferenceTestMain(int argc,
354 char* argv[],
355 const char* modelFilename,
356 bool isModelBinary,
357 const char* inputBindingName,
358 const char* outputBindingName,
359 const std::vector<unsigned int>& defaultTestCaseIds,
360 TConstructDatabaseCallable constructDatabase,
361 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000362
telsoa014fcda012018-03-09 14:13:49 +0000363{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100364 ARMNN_ASSERT(modelFilename);
365 ARMNN_ASSERT(inputBindingName);
366 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000367
telsoa014fcda012018-03-09 14:13:49 +0000368 return InferenceTestMain(argc, argv, defaultTestCaseIds,
369 [=]
370 ()
371 {
telsoa01c577f2c2018-08-31 09:22:23 +0100372 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000373 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
374
375 return make_unique<TestCaseProvider>(constructDatabase,
376 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100377 (const InferenceTestOptions &commonOptions,
378 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000379 {
380 if (!ValidateDirectory(modelOptions.m_ModelDir))
381 {
382 return std::unique_ptr<InferenceModel>();
383 }
384
385 typename InferenceModel::Params modelParams;
386 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000387 modelParams.m_InputBindings = { inputBindingName };
388 modelParams.m_OutputBindings = { outputBindingName };
389
390 if (inputTensorShape)
391 {
392 modelParams.m_InputShapes.push_back(*inputTensorShape);
393 }
394
telsoa014fcda012018-03-09 14:13:49 +0000395 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000396 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100397 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100398 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000399
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100400 return std::make_unique<InferenceModel>(modelParams,
401 commonOptions.m_EnableProfiling,
402 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000403 });
404 });
405}
406
407} // namespace test
408} // namespace armnn