blob: 0112037bc375b97f0a971dbe69f7fcf5dd7eec89 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
#include "InferenceTest.hpp"

#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/log/trivial.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/assert.hpp>
#include <boost/format.hpp>
#include <boost/program_options.hpp>
#include <boost/filesystem/operations.hpp>

#include <algorithm>
#include <array>
#include <chrono>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <utility>

using namespace std;
using namespace std::chrono;
using namespace armnn::test;
25
26namespace armnn
27{
28namespace test
29{
30
Ferran Balaguerc602f292019-02-08 17:09:55 +000031using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010032
telsoa014fcda012018-03-09 14:13:49 +000033template <typename TTestCaseDatabase, typename TModel>
34ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
35 int& numInferencesRef,
36 int& numCorrectInferencesRef,
37 const std::vector<unsigned int>& validationPredictions,
38 std::vector<unsigned int>* validationPredictionsOut,
39 TModel& model,
40 unsigned int testCaseId,
41 unsigned int label,
42 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000043 : InferenceModelTestCase<TModel>(
44 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000045 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010046 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000047 , m_NumInferencesRef(numInferencesRef)
48 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
49 , m_ValidationPredictions(validationPredictions)
50 , m_ValidationPredictionsOut(validationPredictionsOut)
51{
52}
53
54template <typename TTestCaseDatabase, typename TModel>
55TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
56{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000057 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +000058 const auto testCaseId = this->GetTestCaseId();
59
surmeh01bceff2f2018-03-29 16:29:27 +010060 std::map<float,int> resultMap;
61 {
62 int index = 0;
Matthew Bentham4322d362018-10-29 17:39:49 +000063
Ferran Balaguerc602f292019-02-08 17:09:55 +000064 boost::apply_visitor([&](auto&& value)
65 {
66 for (const auto & o : value)
67 {
68 float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams);
69 int classification = index++;
70
71 // Take the first class with each probability
72 // This avoids strange results when looping over batched results produced
73 // with identical test data.
74 std::map<float, int>::iterator lb = resultMap.lower_bound(prob);
75 if (lb == resultMap.end() ||
76 !resultMap.key_comp()(prob, lb->first)) {
77 // If the key is not already in the map, insert it.
78 resultMap.insert(lb, std::map<float, int>::value_type(prob, classification));
79 }
80 }
81 },
82 output);
surmeh01bceff2f2018-03-29 16:29:27 +010083 }
84
85 {
86 BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
87 auto it = resultMap.rbegin();
88 for (int i=0; i<5 && it != resultMap.rend(); ++i)
89 {
90 BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
91 " with confidence: " << 100.0*(it->first) << "%";
92 ++it;
93 }
94 }
95
Ferran Balaguerc602f292019-02-08 17:09:55 +000096 unsigned int prediction = 0;
97 boost::apply_visitor([&](auto&& value)
98 {
99 prediction = boost::numeric_cast<unsigned int>(
100 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
101 },
102 output);
telsoa014fcda012018-03-09 14:13:49 +0000103
telsoa01c577f2c2018-08-31 09:22:23 +0100104 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000105 if (params.m_IterationCount == 0 && prediction != m_Label)
106 {
107 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
108 " is incorrect (should be " << m_Label << ")";
109 return TestCaseResult::Failed;
110 }
111
telsoa01c577f2c2018-08-31 09:22:23 +0100112 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000113 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
114 {
115 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
116 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
117 return TestCaseResult::Failed;
118 }
119
telsoa01c577f2c2018-08-31 09:22:23 +0100120 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000121 if (m_ValidationPredictionsOut)
122 {
123 m_ValidationPredictionsOut->push_back(prediction);
124 }
125
telsoa01c577f2c2018-08-31 09:22:23 +0100126 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000127 m_NumInferencesRef++;
128 if (prediction == m_Label)
129 {
130 m_NumCorrectInferencesRef++;
131 }
132
133 return TestCaseResult::Ok;
134}
135
136template <typename TDatabase, typename InferenceModel>
137template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
138ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
139 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
140 : m_ConstructModel(constructModel)
141 , m_ConstructDatabase(constructDatabase)
142 , m_NumInferences(0)
143 , m_NumCorrectInferences(0)
144{
145}
146
147template <typename TDatabase, typename InferenceModel>
148void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
149 boost::program_options::options_description& options)
150{
151 namespace po = boost::program_options;
152
153 options.add_options()
154 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
155 "Reads expected predictions from the given file and confirms they match the actual predictions.")
156 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
157 "Predictions are saved to the given file for later use via --validation-file-in.")
158 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
159 "Path to directory containing test data");
160
161 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
162}
163
164template <typename TDatabase, typename InferenceModel>
Matthew Bentham3e68b972019-04-09 13:10:46 +0100165bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions(
166 const InferenceTestOptions& commonOptions)
telsoa014fcda012018-03-09 14:13:49 +0000167{
168 if (!ValidateDirectory(m_DataDir))
169 {
170 return false;
171 }
172
173 ReadPredictions();
174
Matthew Bentham3e68b972019-04-09 13:10:46 +0100175 m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
telsoa014fcda012018-03-09 14:13:49 +0000176 if (!m_Model)
177 {
178 return false;
179 }
180
telsoa01c577f2c2018-08-31 09:22:23 +0100181 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000182 if (!m_Database)
183 {
184 return false;
185 }
186
187 return true;
188}
189
190template <typename TDatabase, typename InferenceModel>
191std::unique_ptr<IInferenceTestCase>
192ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
193{
194 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
195 if (testCaseData == nullptr)
196 {
197 return nullptr;
198 }
199
200 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
201 m_NumInferences,
202 m_NumCorrectInferences,
203 m_ValidationPredictions,
204 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
205 *m_Model,
206 testCaseId,
207 testCaseData->m_Label,
208 std::move(testCaseData->m_InputImage));
209}
210
211template <typename TDatabase, typename InferenceModel>
212bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
213{
214 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
215 boost::numeric_cast<double>(m_NumInferences);
216 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
217
telsoa01c577f2c2018-08-31 09:22:23 +0100218 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000219 if (!m_ValidationFileOut.empty())
220 {
221 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
222 if (validationFileOut.good())
223 {
224 for (const unsigned int prediction : m_ValidationPredictionsOut)
225 {
226 validationFileOut << prediction << std::endl;
227 }
228 }
229 else
230 {
231 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
232 return false;
233 }
234 }
235
236 return true;
237}
238
239template <typename TDatabase, typename InferenceModel>
240void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
241{
telsoa01c577f2c2018-08-31 09:22:23 +0100242 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000243 if (!m_ValidationFileIn.empty())
244 {
245 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
246 if (validationFileIn.good())
247 {
248 while (!validationFileIn.eof())
249 {
250 unsigned int i;
251 validationFileIn >> i;
252 m_ValidationPredictions.emplace_back(i);
253 }
254 }
255 else
256 {
257 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
258 % m_ValidationFileIn));
259 }
260 }
261}
262
263template<typename TConstructTestCaseProvider>
264int InferenceTestMain(int argc,
265 char* argv[],
266 const std::vector<unsigned int>& defaultTestCaseIds,
267 TConstructTestCaseProvider constructTestCaseProvider)
268{
telsoa01c577f2c2018-08-31 09:22:23 +0100269 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000270#ifdef NDEBUG
271 armnn::LogSeverity level = armnn::LogSeverity::Info;
272#else
273 armnn::LogSeverity level = armnn::LogSeverity::Debug;
274#endif
275 armnn::ConfigureLogging(true, true, level);
276 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
277
278 try
279 {
280 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
281 if (!testCaseProvider)
282 {
283 return 1;
284 }
285
286 InferenceTestOptions inferenceTestOptions;
287 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
288 {
289 return 1;
290 }
291
292 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
293 return success ? 0 : 1;
294 }
295 catch (armnn::Exception const& e)
296 {
297 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
298 return 1;
299 }
300}
301
telsoa01c577f2c2018-08-31 09:22:23 +0100302//
303// This function allows us to create a classifier inference test based on:
304// - a model file name
305// - which can be a binary or a text file for protobuf formats
306// - an input tensor name
307// - an output tensor name
308// - a set of test case ids
309// - a callback method which creates an object that can return images
310// called 'Database' in these tests
311// - and an input tensor shape
312//
telsoa014fcda012018-03-09 14:13:49 +0000313template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100314 typename TParser,
315 typename TConstructDatabaseCallable>
316int ClassifierInferenceTestMain(int argc,
317 char* argv[],
318 const char* modelFilename,
319 bool isModelBinary,
320 const char* inputBindingName,
321 const char* outputBindingName,
322 const std::vector<unsigned int>& defaultTestCaseIds,
323 TConstructDatabaseCallable constructDatabase,
324 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000325
telsoa014fcda012018-03-09 14:13:49 +0000326{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000327 BOOST_ASSERT(modelFilename);
328 BOOST_ASSERT(inputBindingName);
329 BOOST_ASSERT(outputBindingName);
330
telsoa014fcda012018-03-09 14:13:49 +0000331 return InferenceTestMain(argc, argv, defaultTestCaseIds,
332 [=]
333 ()
334 {
telsoa01c577f2c2018-08-31 09:22:23 +0100335 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000336 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
337
338 return make_unique<TestCaseProvider>(constructDatabase,
339 [&]
Matthew Bentham3e68b972019-04-09 13:10:46 +0100340 (const InferenceTestOptions &commonOptions,
341 typename InferenceModel::CommandLineOptions modelOptions)
telsoa014fcda012018-03-09 14:13:49 +0000342 {
343 if (!ValidateDirectory(modelOptions.m_ModelDir))
344 {
345 return std::unique_ptr<InferenceModel>();
346 }
347
348 typename InferenceModel::Params modelParams;
349 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000350 modelParams.m_InputBindings = { inputBindingName };
351 modelParams.m_OutputBindings = { outputBindingName };
352
353 if (inputTensorShape)
354 {
355 modelParams.m_InputShapes.push_back(*inputTensorShape);
356 }
357
telsoa014fcda012018-03-09 14:13:49 +0000358 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000359 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100360 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100361 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000362
Matthew Bentham3e68b972019-04-09 13:10:46 +0100363 return std::make_unique<InferenceModel>(modelParams, commonOptions.m_EnableProfiling);
telsoa014fcda012018-03-09 14:13:49 +0000364 });
365 });
366}
367
368} // namespace test
369} // namespace armnn