blob: ed16464787c3575b933355e529e132b51e52d7c8 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
#include "InferenceTest.hpp"

#include <armnn/utility/Assert.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/format.hpp>
#include <boost/program_options.hpp>
#include <boost/filesystem/operations.hpp>

#include <array>
#include <chrono>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <utility>
20
21using namespace std;
22using namespace std::chrono;
23using namespace armnn::test;
24
25namespace armnn
26{
27namespace test
28{
29
Ferran Balaguerc602f292019-02-08 17:09:55 +000030using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010031
telsoa014fcda012018-03-09 14:13:49 +000032template <typename TTestCaseDatabase, typename TModel>
33ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
34 int& numInferencesRef,
35 int& numCorrectInferencesRef,
36 const std::vector<unsigned int>& validationPredictions,
37 std::vector<unsigned int>* validationPredictionsOut,
38 TModel& model,
39 unsigned int testCaseId,
40 unsigned int label,
41 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000042 : InferenceModelTestCase<TModel>(
43 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000044 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010045 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000046 , m_NumInferencesRef(numInferencesRef)
47 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
48 , m_ValidationPredictions(validationPredictions)
49 , m_ValidationPredictionsOut(validationPredictionsOut)
50{
51}
52
Derek Lambertiac737602019-05-16 16:33:00 +010053struct ClassifierResultProcessor : public boost::static_visitor<>
54{
55 using ResultMap = std::map<float,int>;
56
57 ClassifierResultProcessor(float scale, int offset)
58 : m_Scale(scale)
59 , m_Offset(offset)
60 {}
61
62 void operator()(const std::vector<float>& values)
63 {
64 SortPredictions(values, [](float value)
65 {
66 return value;
67 });
68 }
69
70 void operator()(const std::vector<uint8_t>& values)
71 {
72 auto& scale = m_Scale;
73 auto& offset = m_Offset;
74 SortPredictions(values, [&scale, &offset](uint8_t value)
75 {
76 return armnn::Dequantize(value, scale, offset);
77 });
78 }
79
80 void operator()(const std::vector<int>& values)
81 {
Jan Eilers8eb25602020-03-09 12:13:48 +000082 IgnoreUnused(values);
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010083 ARMNN_ASSERT_MSG(false, "Non-float predictions output not supported.");
Derek Lambertiac737602019-05-16 16:33:00 +010084 }
85
86 ResultMap& GetResultMap() { return m_ResultMap; }
87
88private:
89 template<typename Container, typename Delegate>
90 void SortPredictions(const Container& c, Delegate delegate)
91 {
92 int index = 0;
93 for (const auto& value : c)
94 {
95 int classification = index++;
96 // Take the first class with each probability
97 // This avoids strange results when looping over batched results produced
98 // with identical test data.
99 ResultMap::iterator lb = m_ResultMap.lower_bound(value);
100
101 if (lb == m_ResultMap.end() || !m_ResultMap.key_comp()(value, lb->first))
102 {
103 // If the key is not already in the map, insert it.
104 m_ResultMap.insert(lb, ResultMap::value_type(delegate(value), classification));
105 }
106 }
107 }
108
109 ResultMap m_ResultMap;
110
111 float m_Scale=0.0f;
112 int m_Offset=0;
113};
114
telsoa014fcda012018-03-09 14:13:49 +0000115template <typename TTestCaseDatabase, typename TModel>
116TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
117{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000118 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +0000119 const auto testCaseId = this->GetTestCaseId();
120
Derek Lambertiac737602019-05-16 16:33:00 +0100121 ClassifierResultProcessor resultProcessor(m_QuantizationParams.first, m_QuantizationParams.second);
122 boost::apply_visitor(resultProcessor, output);
123
Derek Lamberti08446972019-11-26 16:38:31 +0000124 ARMNN_LOG(info) << "= Prediction values for test #" << testCaseId;
Derek Lambertiac737602019-05-16 16:33:00 +0100125 auto it = resultProcessor.GetResultMap().rbegin();
126 for (int i=0; i<5 && it != resultProcessor.GetResultMap().rend(); ++i)
surmeh01bceff2f2018-03-29 16:29:27 +0100127 {
Derek Lamberti08446972019-11-26 16:38:31 +0000128 ARMNN_LOG(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
Derek Lambertiac737602019-05-16 16:33:00 +0100129 " with value: " << (it->first);
130 ++it;
surmeh01bceff2f2018-03-29 16:29:27 +0100131 }
132
Ferran Balaguerc602f292019-02-08 17:09:55 +0000133 unsigned int prediction = 0;
134 boost::apply_visitor([&](auto&& value)
135 {
136 prediction = boost::numeric_cast<unsigned int>(
137 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
138 },
139 output);
telsoa014fcda012018-03-09 14:13:49 +0000140
telsoa01c577f2c2018-08-31 09:22:23 +0100141 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000142 if (params.m_IterationCount == 0 && prediction != m_Label)
143 {
Derek Lamberti08446972019-11-26 16:38:31 +0000144 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000145 " is incorrect (should be " << m_Label << ")";
146 return TestCaseResult::Failed;
147 }
148
telsoa01c577f2c2018-08-31 09:22:23 +0100149 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000150 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
151 {
Derek Lamberti08446972019-11-26 16:38:31 +0000152 ARMNN_LOG(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
telsoa014fcda012018-03-09 14:13:49 +0000153 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
154 return TestCaseResult::Failed;
155 }
156
telsoa01c577f2c2018-08-31 09:22:23 +0100157 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000158 if (m_ValidationPredictionsOut)
159 {
160 m_ValidationPredictionsOut->push_back(prediction);
161 }
162
telsoa01c577f2c2018-08-31 09:22:23 +0100163 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000164 m_NumInferencesRef++;
165 if (prediction == m_Label)
166 {
167 m_NumCorrectInferencesRef++;
168 }
169
170 return TestCaseResult::Ok;
171}
172
173template <typename TDatabase, typename InferenceModel>
174template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
175ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
176 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
177 : m_ConstructModel(constructModel)
178 , m_ConstructDatabase(constructDatabase)
179 , m_NumInferences(0)
180 , m_NumCorrectInferences(0)
181{
182}
183
184template <typename TDatabase, typename InferenceModel>
185void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
186 boost::program_options::options_description& options)
187{
188 namespace po = boost::program_options;
189
190 options.add_options()
191 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
192 "Reads expected predictions from the given file and confirms they match the actual predictions.")
193 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
194 "Predictions are saved to the given file for later use via --validation-file-in.")
195 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
196 "Path to directory containing test data");
197
198 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
199}
200
201template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100202bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
203 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000204{
205 if (!ValidateDirectory(m_DataDir))
206 {
207 return false;
208 }
209
210 ReadPredictions();
211
Matthew Bentham3e68b972019-04-09 13:10:46 +0100212 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000213 if (!m_Model)
214 {
215 return false;
216 }
217
telsoa01c577f2c2018-08-31 09:22:23 +0100218 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000219 if (!m_Database)
220 {
221 return false;
222 }
223
224 return true;
225}
226
227template <typename TDatabase, typename InferenceModel>
228std::unique_ptr<IInferenceTestCase>
229ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
230{
231 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
232 if (testCaseData == nullptr)
233 {
234 return nullptr;
235 }
236
237 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
238 m_NumInferences,
239 m_NumCorrectInferences,
240 m_ValidationPredictions,
241 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
242 *m_Model,
243 testCaseId,
244 testCaseData->m_Label,
245 std::move(testCaseData->m_InputImage));
246}
247
248template <typename TDatabase, typename InferenceModel>
249bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
250{
251 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
252 boost::numeric_cast<double>(m_NumInferences);
Derek Lamberti08446972019-11-26 16:38:31 +0000253 ARMNN_LOG(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
telsoa014fcda012018-03-09 14:13:49 +0000254
telsoa01c577f2c2018-08-31 09:22:23 +0100255 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000256 if (!m_ValidationFileOut.empty())
257 {
258 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
259 if (validationFileOut.good())
260 {
261 for (const unsigned int prediction : m_ValidationPredictionsOut)
262 {
263 validationFileOut << prediction << std::endl;
264 }
265 }
266 else
267 {
Derek Lamberti08446972019-11-26 16:38:31 +0000268 ARMNN_LOG(error) << "Failed to open output validation file: " << m_ValidationFileOut;
telsoa014fcda012018-03-09 14:13:49 +0000269 return false;
270 }
271 }
272
273 return true;
274}
275
276template <typename TDatabase, typename InferenceModel>
277void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
278{
telsoa01c577f2c2018-08-31 09:22:23 +0100279 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000280 if (!m_ValidationFileIn.empty())
281 {
282 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
283 if (validationFileIn.good())
284 {
285 while (!validationFileIn.eof())
286 {
287 unsigned int i;
288 validationFileIn >> i;
289 m_ValidationPredictions.emplace_back(i);
290 }
291 }
292 else
293 {
294 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
295 % m_ValidationFileIn));
296 }
297 }
298}
299
300template<typename TConstructTestCaseProvider>
301int InferenceTestMain(int argc,
302 char* argv[],
303 const std::vector<unsigned int>& defaultTestCaseIds,
304 TConstructTestCaseProvider constructTestCaseProvider)
305{
telsoa01c577f2c2018-08-31 09:22:23 +0100306 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000307#ifdef NDEBUG
308 armnn::LogSeverity level = armnn::LogSeverity::Info;
309#else
310 armnn::LogSeverity level = armnn::LogSeverity::Debug;
311#endif
312 armnn::ConfigureLogging(true, true, level);
telsoa014fcda012018-03-09 14:13:49 +0000313
314 try
315 {
316 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
317 if (!testCaseProvider)
318 {
319 return 1;
320 }
321
322 InferenceTestOptions inferenceTestOptions;
323 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
324 {
325 return 1;
326 }
327
328 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
329 return success ? 0 : 1;
330 }
331 catch (armnn::Exception const& e)
332 {
Derek Lamberti08446972019-11-26 16:38:31 +0000333 ARMNN_LOG(fatal) << "Armnn Error: " << e.what();
telsoa014fcda012018-03-09 14:13:49 +0000334 return 1;
335 }
336}
337
telsoa01c577f2c2018-08-31 09:22:23 +0100338//
339// This function allows us to create a classifier inference test based on:
340// - a model file name
341// - which can be a binary or a text file for protobuf formats
342// - an input tensor name
343// - an output tensor name
344// - a set of test case ids
345// - a callback method which creates an object that can return images
346// called 'Database' in these tests
347// - and an input tensor shape
348//
telsoa014fcda012018-03-09 14:13:49 +0000349template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100350 typename TParser,
351 typename TConstructDatabaseCallable>
352int ClassifierInferenceTestMain(int argc,
353 char* argv[],
354 const char* modelFilename,
355 bool isModelBinary,
356 const char* inputBindingName,
357 const char* outputBindingName,
358 const std::vector<unsigned int>& defaultTestCaseIds,
359 TConstructDatabaseCallable constructDatabase,
360 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000361
telsoa014fcda012018-03-09 14:13:49 +0000362{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100363 ARMNN_ASSERT(modelFilename);
364 ARMNN_ASSERT(inputBindingName);
365 ARMNN_ASSERT(outputBindingName);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000366
telsoa014fcda012018-03-09 14:13:49 +0000367 return InferenceTestMain(argc, argv, defaultTestCaseIds,
368 [=]
369 ()
370 {
telsoa01c577f2c2018-08-31 09:22:23 +0100371 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000372 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
373
374 return make_unique<TestCaseProvider>(constructDatabase,
375 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100376 (const InferenceTestOptions &commonOptions,
377 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000378 {
379 if (!ValidateDirectory(modelOptions.m_ModelDir))
380 {
381 return std::unique_ptr<InferenceModel>();
382 }
383
384 typename InferenceModel::Params modelParams;
385 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000386 modelParams.m_InputBindings = { inputBindingName };
387 modelParams.m_OutputBindings = { outputBindingName };
388
389 if (inputTensorShape)
390 {
391 modelParams.m_InputShapes.push_back(*inputTensorShape);
392 }
393
telsoa014fcda012018-03-09 14:13:49 +0000394 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000395 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100396 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100397 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000398
Matteo Martincigh00dda4a2019-08-14 11:42:30 +0100399 return std::make_unique<InferenceModel>(modelParams,
400 commonOptions.m_EnableProfiling,
401 commonOptions.m_DynamicBackendsPath);
telsoa014fcda012018-03-09 14:13:49 +0000402 });
403 });
404}
405
406} // namespace test
407} // namespace armnn