//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "InferenceTest.hpp"

#include <boost/algorithm/string.hpp>
#include <boost/numeric/conversion/cast.hpp>
#include <boost/log/trivial.hpp>
#include <boost/filesystem/path.hpp>
#include <boost/assert.hpp>
#include <boost/format.hpp>
#include <boost/program_options.hpp>
#include <boost/filesystem/operations.hpp>

#include <fstream>
#include <iostream>
#include <iomanip>
#include <array>
#include <chrono>

using namespace std;
using namespace std::chrono;
using namespace armnn::test;

namespace armnn
{
namespace test
{


template <typename TTestCaseDatabase, typename TModel>
ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
    int& numInferencesRef,
    int& numCorrectInferencesRef,
    const std::vector<unsigned int>& validationPredictions,
    std::vector<unsigned int>* validationPredictionsOut,
    TModel& model,
    unsigned int testCaseId,
    unsigned int label,
    std::vector<typename TModel::DataType> modelInput)
    : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
    , m_Label(label)
    , m_QuantizationParams(model.GetQuantizationParams())
    , m_NumInferencesRef(numInferencesRef)
    , m_NumCorrectInferencesRef(numCorrectInferencesRef)
    , m_ValidationPredictions(validationPredictions)
    , m_ValidationPredictionsOut(validationPredictionsOut)
{
}

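// Scores a single inference result: logs the top predictions, compares the most confident class against the
// expected label and (if supplied) the validation file, and updates the shared accuracy counters.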
template <typename TTestCaseDatabase, typename TModel>
TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
{
    auto& output = this->GetOutput();
    const auto testCaseId = this->GetTestCaseId();

    std::map<float, int> resultMap;
    {
        int index = 0;
        for (const auto& o : output)
        {
            float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams);
            int classification = index++;

            // Take the first class with each probability. This avoids strange results when looping over
            // batched results produced with identical test data.
            std::map<float, int>::iterator lb = resultMap.lower_bound(prob);
            if (lb == resultMap.end() || !resultMap.key_comp()(prob, lb->first))
            {
                // If the key is not already in the map, insert it.
                resultMap.insert(lb, std::map<float, int>::value_type(prob, classification));
            }
        }
    }

    {
        BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
        auto it = resultMap.rbegin();
        for (int i = 0; i < 5 && it != resultMap.rend(); ++i)
        {
            BOOST_LOG_TRIVIAL(info) << "Top(" << (i + 1) << ") prediction is " << it->second <<
                " with confidence: " << 100.0 * (it->first) << "%";
            ++it;
        }
    }

    const unsigned int prediction = boost::numeric_cast<unsigned int>(
        std::distance(output.begin(), std::max_element(output.begin(), output.end())));

    // If we're just running the defaultTestCaseIds, each one must be classified correctly.
    if (params.m_IterationCount == 0 && prediction != m_Label)
    {
        BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
            " is incorrect (should be " << m_Label << ")";
        return TestCaseResult::Failed;
    }

    // If a validation file was provided as input, check that the prediction matches the expected one.
    if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
    {
        BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
            " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
        return TestCaseResult::Failed;
    }

    // If a validation file was requested as output, store the prediction for writing later.
    if (m_ValidationPredictionsOut)
    {
        m_ValidationPredictionsOut->push_back(prediction);
    }

    // Update accuracy stats.
    m_NumInferencesRef++;
    if (prediction == m_Label)
    {
        m_NumCorrectInferencesRef++;
    }

    return TestCaseResult::Ok;
}

template <typename TDatabase, typename InferenceModel>
template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
    TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
    : m_ConstructModel(constructModel)
    , m_ConstructDatabase(constructDatabase)
    , m_NumInferences(0)
    , m_NumCorrectInferences(0)
{
}

template <typename TDatabase, typename InferenceModel>
void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
    boost::program_options::options_description& options)
{
    namespace po = boost::program_options;

    options.add_options()
        ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
            "Reads expected predictions from the given file and confirms they match the actual predictions.")
        ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
            "Predictions are saved to the given file for later use via --validation-file-in.")
        ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
            "Path to directory containing test data");

    InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
}
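
// Example command lines (illustrative only; the executable name and paths are hypothetical, and further
// options such as the model directory are added by InferenceModel::AddCommandLineOptions above):
//
//     MyClassifierTest --data-dir=/path/to/test/data --validation-file-out=predictions.txt
//     MyClassifierTest --data-dir=/path/to/test/data --validation-file-in=predictions.txt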

template <typename TDatabase, typename InferenceModel>
bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
{
    if (!ValidateDirectory(m_DataDir))
    {
        return false;
    }

    ReadPredictions();

    m_Model = m_ConstructModel(m_ModelCommandLineOptions);
    if (!m_Model)
    {
        return false;
    }

    m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
    if (!m_Database)
    {
        return false;
    }

    return true;
}

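// Creates the test case for the given id, pulling its input data from the database; returns nullptr when the
// database has no data for that id.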
template <typename TDatabase, typename InferenceModel>
std::unique_ptr<IInferenceTestCase>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
{
    std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
    if (testCaseData == nullptr)
    {
        return nullptr;
    }

    return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
        m_NumInferences,
        m_NumCorrectInferences,
        m_ValidationPredictions,
        m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
        *m_Model,
        testCaseId,
        testCaseData->m_Label,
        std::move(testCaseData->m_InputImage));
}

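// Called once all test cases have finished: logs the overall accuracy and, if --validation-file-out was given,
// writes the recorded predictions to that file (one per line).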
template <typename TDatabase, typename InferenceModel>
bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
{
    const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
        boost::numeric_cast<double>(m_NumInferences);
    BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;

    // If a validation file was requested as output, the predictions are saved to it.
    if (!m_ValidationFileOut.empty())
    {
        std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
        if (validationFileOut.good())
        {
            for (const unsigned int prediction : m_ValidationPredictionsOut)
            {
                validationFileOut << prediction << std::endl;
            }
        }
        else
        {
            BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
            return false;
        }
    }

    return true;
}

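// Populates m_ValidationPredictions from the file given via --validation-file-in; the expected format is one
// integer prediction per line, as produced by --validation-file-out.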
template <typename TDatabase, typename InferenceModel>
void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
{
    // Reads the expected predictions from the input validation file (if provided).
    if (!m_ValidationFileIn.empty())
    {
        std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
        if (validationFileIn.good())
        {
            unsigned int i;
            // Loop on the extraction itself rather than eof(), so a trailing newline doesn't append a bogus value.
            while (validationFileIn >> i)
            {
                m_ValidationPredictions.emplace_back(i);
            }
        }
        else
        {
            throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
                % m_ValidationFileIn));
        }
    }
}

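// Common entry point for the inference test programs: configures logging, constructs the test case provider,
// parses the command line and runs the test, returning 0 on success and 1 on any failure.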
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
    char* argv[],
    const std::vector<unsigned int>& defaultTestCaseIds,
    TConstructTestCaseProvider constructTestCaseProvider)
{
    // Configures logging for both the ARMNN library and this test program.
#ifdef NDEBUG
    armnn::LogSeverity level = armnn::LogSeverity::Info;
#else
    armnn::LogSeverity level = armnn::LogSeverity::Debug;
#endif
    armnn::ConfigureLogging(true, true, level);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);

    try
    {
        std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
        if (!testCaseProvider)
        {
            return 1;
        }

        InferenceTestOptions inferenceTestOptions;
        if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
        {
            return 1;
        }

        const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
        return success ? 0 : 1;
    }
    catch (armnn::Exception const& e)
    {
        BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
        return 1;
    }
}

//
// This function allows us to create a classifier inference test based on:
//  - a model file name, which can be a binary or a text file for protobuf formats
//  - an input tensor name
//  - an output tensor name
//  - a set of test case ids
//  - a callback which creates an object (called a 'Database' in these tests) that can return images
//  - and an input tensor shape
//
template<typename TDatabase,
         typename TParser,
         typename TConstructDatabaseCallable>
int ClassifierInferenceTestMain(int argc,
                                char* argv[],
                                const char* modelFilename,
                                bool isModelBinary,
                                const char* inputBindingName,
                                const char* outputBindingName,
                                const std::vector<unsigned int>& defaultTestCaseIds,
                                TConstructDatabaseCallable constructDatabase,
                                const armnn::TensorShape* inputTensorShape)
{
    return InferenceTestMain(argc, argv, defaultTestCaseIds,
        [=]
        ()
        {
            using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
            using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;

            return make_unique<TestCaseProvider>(constructDatabase,
                [&]
                (typename InferenceModel::CommandLineOptions modelOptions)
                {
                    if (!ValidateDirectory(modelOptions.m_ModelDir))
                    {
                        return std::unique_ptr<InferenceModel>();
                    }

                    typename InferenceModel::Params modelParams;
                    modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
                    modelParams.m_InputBinding = inputBindingName;
                    modelParams.m_OutputBinding = outputBindingName;
                    modelParams.m_InputTensorShape = inputTensorShape;
                    modelParams.m_IsModelBinary = isModelBinary;
                    modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
                    modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
                    modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;

                    return std::make_unique<InferenceModel>(modelParams);
                });
        });
}

} // namespace test
} // namespace armnn