blob: 4a97d17018aeeeb4153116d7f67c77366f5678fb [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
#include "InferenceTest.hpp"

#include <armnn/utility/Assert.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/program_options.hpp>
#include <fmt/format.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <string>
#include <utility>
#include <vector>
17
18using namespace std;
19using namespace std::chrono;
20using namespace armnn::test;
21
22namespace armnn
23{
24namespace test
25{
26
Ferran Balaguerc602f292019-02-08 17:09:55 +000027using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010028
telsoa014fcda012018-03-09 14:13:49 +000029template <typename TTestCaseDatabase, typename TModel>
30ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
31 int& numInferencesRef,
32 int& numCorrectInferencesRef,
33 const std::vector<unsigned int>& validationPredictions,
34 std::vector<unsigned int>* validationPredictionsOut,
35 TModel& model,
36 unsigned int testCaseId,
37 unsigned int label,
38 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000039 : InferenceModelTestCase<TModel>(
40 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000041 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010042 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000043 , m_NumInferencesRef(numInferencesRef)
44 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
45 , m_ValidationPredictions(validationPredictions)
46 , m_ValidationPredictionsOut(validationPredictionsOut)
47{
48}
49
Derek Lambertiac737602019-05-16 16:33:00 +010050struct ClassifierResultProcessor : public boost::static_visitor<>
51{
52 using ResultMap = std::map<float,int>;
53
54 ClassifierResultProcessor(float scale, int offset)
55 : m_Scale(scale)
56 , m_Offset(offset)
57 {}
58
59 void operator()(const std::vector<float>& values)
60 {
61 SortPredictions(values, [](float value)
62 {
63 return value;
64 });
65 }
66
67 void operator()(const std::vector<uint8_t>& values)
68 {
69 auto& scale = m_Scale;
70 auto& offset = m_Offset;
71 SortPredictions(values, [&scale, &offset](uint8_t value)
72 {
73 return armnn::Dequantize(value, scale, offset);
74 });
75 }
76
77 void operator()(const std::vector<int>& values)
78 {
Jan Eilers8eb25602020-03-09 12:13:48 +000079 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010080 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010081 }
82
83 ResultMap& GetResultMap() { return m_ResultMap; }
84
85private:
86 template<typename Container, typename Delegate>
87 void SortPredictions(const Container& c, Delegate delegate)
88 {
89 int index = 0;
90 for (const auto& value : c)
91 {
92 int classification = index++;
93 // Take the first class with each probability
94 // This avoids strange results when looping over batched results produced
95 // with identical test data.
96 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
97
98 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
99 {
100 // If the key is not already in the map, insert it.
101 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
102 }
103 }
104 }
105
106 ResultMap m_ResultMap;
107
108 float m_Scale=0.0f;
109 int m_Offset=0;
110};
111
telsoa014fcda012018-03-09 14:13:49 +0000112template <typename TTestCaseDatabase, typename TModel>
113TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
114{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000115 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000116 const auto testCaseId = this->GetTestCaseId();
117
Derek Lambertiac737602019-05-16 16:33:00 +0100118 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
119 boost::apply_visitor(resultProcessor, output);
120
Derek Lamberti08446972019-11-26 16:38:31 +0000121 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100122 auto it = resultProcessor.GetResultMap().rbegin();
123 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100124 {
Derek Lamberti08446972019-11-26 16:38:31 +0000125 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100126 " with value: " << (it->first);
127 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100128 }
129
Ferran Balaguerc602f292019-02-08 17:09:55 +0000130 unsigned int prediction = 0;
131 boost::apply_visitor([&](auto&& value)
132 {
133 prediction = boost::numeric_cast<unsigned int>(
134 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
135 },
136 output);
telsoa014fcda012018-03-09 14:13:49 +0000137
telsoa01c577f2c2018-08-31 09:22:23 +0100138 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000139 if (params.m_IterationCount == 0 && prediction != m_Label)
140 {
Derek Lamberti08446972019-11-26 16:38:31 +0000141 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000142 " is incorrect (should be " << m_Label << ")";
143 return TestCaseResult::Failed;
144 }
145
telsoa01c577f2c2018-08-31 09:22:23 +0100146 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000147 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
148 {
Derek Lamberti08446972019-11-26 16:38:31 +0000149 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000150 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
151 return TestCaseResult::Failed;
152 }
153
telsoa01c577f2c2018-08-31 09:22:23 +0100154 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000155 if (m_ValidationPredictionsOut)
156 {
157 m_ValidationPredictionsOut->push_back(prediction);
158 }
159
telsoa01c577f2c2018-08-31 09:22:23 +0100160 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000161 m_NumInferencesRef++;
162 if (prediction == m_Label)
163 {
164 m_NumCorrectInferencesRef++;
165 }
166
167 return TestCaseResult::Ok;
168}
169
170template <typename TDatabase, typename InferenceModel>
171template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
172ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
173 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
174 : m_ConstructModel(constructModel)
175 , m_ConstructDatabase(constructDatabase)
176 , m_NumInferences(0)
177 , m_NumCorrectInferences(0)
178{
179}
180
181template <typename TDatabase, typename InferenceModel>
182void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
183 boost::program_options::options_description& options)
184{
185 namespace po = boost::program_options;
186
187 options.add_options()
188 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
189 "Reads expected predictions from the given file and confirms they match the actual predictions.")
190 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
191 "Predictions are saved to the given file for later use via --validation-file-in.")
192 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
193 "Path to directory containing test data");
194
195 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
196}
197
198template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100199bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
200 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000201{
202 if (!ValidateDirectory(m_DataDir))
203 {
204 return false;
205 }
206
207 ReadPredictions();
208
Matthew Bentham3e68b972019-04-09 13:10:46 +0100209 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000210 if (!m_Model)
211 {
212 return false;
213 }
214
telsoa01c577f2c2018-08-31 09:22:23 +0100215 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000216 if (!m_Database)
217 {
218 return false;
219 }
220
221 return true;
222}
223
224template <typename TDatabase, typename InferenceModel>
225std::unique_ptr<IInferenceTestCase>
226ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
227{
228 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
229 if (testCaseData == nullptr)
230 {
231 return nullptr;
232 }
233
234 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
235 m_NumInferences,
236 m_NumCorrectInferences,
237 m_ValidationPredictions,
238 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
239 *m_Model,
240 testCaseId,
241 testCaseData->m_Label,
242 std::move(testCaseData->m_InputImage));
243}
244
245template <typename TDatabase, typename InferenceModel>
246bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
247{
248 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
249 boost::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000250 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000251
telsoa01c577f2c2018-08-31 09:22:23 +0100252 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000253 if (!m_ValidationFileOut.empty())
254 {
255 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
256 if (validationFileOut.good())
257 {
258 for (const unsigned int prediction : m_ValidationPredictionsOut)
259 {
260 validationFileOut << prediction << std::endl;
261 }
262 }
263 else
264 {
Derek Lamberti08446972019-11-26 16:38:31 +0000265 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000266 return false;
267 }
268 }
269
270 return true;
271}
272
273template <typename TDatabase, typename InferenceModel>
274void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
275{
telsoa01c577f2c2018-08-31 09:22:23 +0100276 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000277 if (!m_ValidationFileIn.empty())
278 {
279 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
280 if (validationFileIn.good())
281 {
282 while (!validationFileIn.eof())
283 {
284 unsigned int i;
285 validationFileIn >> i;
286 m_ValidationPredictions.emplace_back(i);
287 }
288 }
289 else
290 {
James Ward08f40162020-09-07 16:45:07 +0100291 throw armnn::Exception(fmt::format("Failed to open input validation file: {}"
292 , m_ValidationFileIn));
telsoa014fcda012018-03-09 14:13:49 +0000293 }
294 }
295}
296
297template<typename TConstructTestCaseProvider>
298int InferenceTestMain(int argc,
299 char* argv[],
300 const std::vector<unsigned int>& defaultTestCaseIds,
301 TConstructTestCaseProvider constructTestCaseProvider)
302{
telsoa01c577f2c2018-08-31 09:22:23 +0100303 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000304#ifdef NDEBUG
305 armnn::LogSeverity level = armnn::LogSeverity::Info;
306#else
307 armnn::LogSeverity level = armnn::LogSeverity::Debug;
308#endif
309 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000310
311 try
312 {
313 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
314 if (!testCaseProvider)
315 {
316 return 1;
317 }
318
319 InferenceTestOptions inferenceTestOptions;
320 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
321 {
322 return 1;
323 }
324
325 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
326 return success ? 0 : 1;
327 }
328 catch (armnn::Exception const& e)
329 {
Derek Lamberti08446972019-11-26 16:38:31 +0000330 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000331 return 1;
332 }
333}
334
telsoa01c577f2c2018-08-31 09:22:23 +0100335//
336// This function allows us to create a classifier inference test based on:
337// - a model file name
338// - which can be a binary or a text file for protobuf formats
339// - an input tensor name
340// - an output tensor name
341// - a set of test case ids
342// - a callback method which creates an object that can return images
343// called 'Database' in these tests
344// - and an input tensor shape
345//
telsoa014fcda012018-03-09 14:13:49 +0000346template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100347 typename TParser,
348 typename TConstructDatabaseCallable>
349int ClassifierInferenceTestMain(int argc,
350 char* argv[],
351 const char* modelFilename,
352 bool isModelBinary,
353 const char* inputBindingName,
354 const char* outputBindingName,
355 const std::vector<unsigned int>& defaultTestCaseIds,
356 TConstructDatabaseCallable constructDatabase,
357 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000358
telsoa014fcda012018-03-09 14:13:49 +0000359{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100360 ARMNN_ASSERT(modelFilename);
361 ARMNN_ASSERT(inputBindingName);
362 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000363
telsoa014fcda012018-03-09 14:13:49 +0000364 return InferenceTestMain(argc, argv, defaultTestCaseIds,
365 [=]
366 ()
367 {
telsoa01c577f2c2018-08-31 09:22:23 +0100368 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000369 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
370
371 return make_unique<TestCaseProvider>(constructDatabase,
372 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100373 (const InferenceTestOptions &commonOptions,
374 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000375 {
376 if (!ValidateDirectory(modelOptions.m_ModelDir))
377 {
378 return std::unique_ptr<InferenceModel>();
379 }
380
381 typename InferenceModel::Params modelParams;
382 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000383 modelParams.m_InputBindings = { inputBindingName };
384 modelParams.m_OutputBindings = { outputBindingName };
385
386 if (inputTensorShape)
387 {
388 modelParams.m_InputShapes.push_back(*inputTensorShape);
389 }
390
telsoa014fcda012018-03-09 14:13:49 +0000391 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000392 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100393 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100394 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000395
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100396 return std::make_unique<InferenceModel>(modelParams,
397 commonOptions.m_EnableProfiling,
398 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000399 });
400 });
401}
402
403} // namespace test
404} // namespace armnn