blob: 538720bd83a9c2eae85e01b0949315f4e7016bc2 [file] [log] [blame]
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +00001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
telsoa014fcda012018-03-09 14:13:49 +00007#include <boost/algorithm/string.hpp>
8#include <boost/numeric/conversion/cast.hpp>
9#include <boost/log/trivial.hpp>
10#include <boost/filesystem/path.hpp>
11#include <boost/assert.hpp>
12#include <boost/format.hpp>
13#include <boost/program_options.hpp>
14#include <boost/filesystem/operations.hpp>
15
16#include <fstream>
17#include <iostream>
18#include <iomanip>
19#include <array>
20#include <chrono>
21
22using namespace std;
23using namespace std::chrono;
24using namespace armnn::test;
25
26namespace armnn
27{
28namespace test
29{
30
Ferran Balaguerc602f292019-02-08 17:09:55 +000031using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
telsoa01c577f2c2018-08-31 09:22:23 +010032
telsoa014fcda012018-03-09 14:13:49 +000033template <typename TTestCaseDatabase, typename TModel>
34ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
35 int& numInferencesRef,
36 int& numCorrectInferencesRef,
37 const std::vector<unsigned int>& validationPredictions,
38 std::vector<unsigned int>* validationPredictionsOut,
39 TModel& model,
40 unsigned int testCaseId,
41 unsigned int label,
42 std::vector<typename TModel::DataType> modelInput)
Ferran Balaguerc602f292019-02-08 17:09:55 +000043 : InferenceModelTestCase<TModel>(
44 model, testCaseId, std::vector<TContainer>{ modelInput }, { model.GetOutputSize() })
telsoa014fcda012018-03-09 14:13:49 +000045 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010046 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000047 , m_NumInferencesRef(numInferencesRef)
48 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
49 , m_ValidationPredictions(validationPredictions)
50 , m_ValidationPredictionsOut(validationPredictionsOut)
51{
52}
53
54template <typename TTestCaseDatabase, typename TModel>
55TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
56{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +000057 auto& output = this->GetOutputs()[0];
telsoa014fcda012018-03-09 14:13:49 +000058 const auto testCaseId = this->GetTestCaseId();
59
surmeh01bceff2f2018-03-29 16:29:27 +010060 std::map<float,int> resultMap;
61 {
62 int index = 0;
Matthew Bentham4322d362018-10-29 17:39:49 +000063
Ferran Balaguerc602f292019-02-08 17:09:55 +000064 boost::apply_visitor([&](auto&& value)
65 {
66 for (const auto & o : value)
67 {
68 float prob = ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams);
69 int classification = index++;
70
71 // Take the first class with each probability
72 // This avoids strange results when looping over batched results produced
73 // with identical test data.
74 std::map<float, int>::iterator lb = resultMap.lower_bound(prob);
75 if (lb == resultMap.end() ||
76 !resultMap.key_comp()(prob, lb->first)) {
77 // If the key is not already in the map, insert it.
78 resultMap.insert(lb, std::map<float, int>::value_type(prob, classification));
79 }
80 }
81 },
82 output);
surmeh01bceff2f2018-03-29 16:29:27 +010083 }
84
85 {
86 BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
87 auto it = resultMap.rbegin();
88 for (int i=0; i<5 && it != resultMap.rend(); ++i)
89 {
90 BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
91 " with confidence: " << 100.0*(it->first) << "%";
92 ++it;
93 }
94 }
95
Ferran Balaguerc602f292019-02-08 17:09:55 +000096 unsigned int prediction = 0;
97 boost::apply_visitor([&](auto&& value)
98 {
99 prediction = boost::numeric_cast<unsigned int>(
100 std::distance(value.begin(), std::max_element(value.begin(), value.end())));
101 },
102 output);
telsoa014fcda012018-03-09 14:13:49 +0000103
telsoa01c577f2c2018-08-31 09:22:23 +0100104 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +0000105 if (params.m_IterationCount == 0 && prediction != m_Label)
106 {
107 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
108 " is incorrect (should be " << m_Label << ")";
109 return TestCaseResult::Failed;
110 }
111
telsoa01c577f2c2018-08-31 09:22:23 +0100112 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +0000113 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
114 {
115 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
116 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
117 return TestCaseResult::Failed;
118 }
119
telsoa01c577f2c2018-08-31 09:22:23 +0100120 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +0000121 if (m_ValidationPredictionsOut)
122 {
123 m_ValidationPredictionsOut->push_back(prediction);
124 }
125
telsoa01c577f2c2018-08-31 09:22:23 +0100126 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000127 m_NumInferencesRef++;
128 if (prediction == m_Label)
129 {
130 m_NumCorrectInferencesRef++;
131 }
132
133 return TestCaseResult::Ok;
134}
135
136template <typename TDatabase, typename InferenceModel>
137template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
138ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
139 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
140 : m_ConstructModel(constructModel)
141 , m_ConstructDatabase(constructDatabase)
142 , m_NumInferences(0)
143 , m_NumCorrectInferences(0)
144{
145}
146
147template <typename TDatabase, typename InferenceModel>
148void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
149 boost::program_options::options_description& options)
150{
151 namespace po = boost::program_options;
152
153 options.add_options()
154 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
155 "Reads expected predictions from the given file and confirms they match the actual predictions.")
156 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
157 "Predictions are saved to the given file for later use via --validation-file-in.")
158 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
159 "Path to directory containing test data");
160
161 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
162}
163
164template <typename TDatabase, typename InferenceModel>
165bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
166{
167 if (!ValidateDirectory(m_DataDir))
168 {
169 return false;
170 }
171
172 ReadPredictions();
173
174 m_Model = m_ConstructModel(m_ModelCommandLineOptions);
175 if (!m_Model)
176 {
177 return false;
178 }
179
telsoa01c577f2c2018-08-31 09:22:23 +0100180 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000181 if (!m_Database)
182 {
183 return false;
184 }
185
186 return true;
187}
188
189template <typename TDatabase, typename InferenceModel>
190std::unique_ptr<IInferenceTestCase>
191ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
192{
193 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
194 if (testCaseData == nullptr)
195 {
196 return nullptr;
197 }
198
199 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
200 m_NumInferences,
201 m_NumCorrectInferences,
202 m_ValidationPredictions,
203 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
204 *m_Model,
205 testCaseId,
206 testCaseData->m_Label,
207 std::move(testCaseData->m_InputImage));
208}
209
210template <typename TDatabase, typename InferenceModel>
211bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
212{
213 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
214 boost::numeric_cast<double>(m_NumInferences);
215 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
216
telsoa01c577f2c2018-08-31 09:22:23 +0100217 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000218 if (!m_ValidationFileOut.empty())
219 {
220 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
221 if (validationFileOut.good())
222 {
223 for (const unsigned int prediction : m_ValidationPredictionsOut)
224 {
225 validationFileOut << prediction << std::endl;
226 }
227 }
228 else
229 {
230 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
231 return false;
232 }
233 }
234
235 return true;
236}
237
238template <typename TDatabase, typename InferenceModel>
239void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
240{
telsoa01c577f2c2018-08-31 09:22:23 +0100241 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000242 if (!m_ValidationFileIn.empty())
243 {
244 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
245 if (validationFileIn.good())
246 {
247 while (!validationFileIn.eof())
248 {
249 unsigned int i;
250 validationFileIn >> i;
251 m_ValidationPredictions.emplace_back(i);
252 }
253 }
254 else
255 {
256 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
257 % m_ValidationFileIn));
258 }
259 }
260}
261
262template<typename TConstructTestCaseProvider>
263int InferenceTestMain(int argc,
264 char* argv[],
265 const std::vector<unsigned int>& defaultTestCaseIds,
266 TConstructTestCaseProvider constructTestCaseProvider)
267{
telsoa01c577f2c2018-08-31 09:22:23 +0100268 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000269#ifdef NDEBUG
270 armnn::LogSeverity level = armnn::LogSeverity::Info;
271#else
272 armnn::LogSeverity level = armnn::LogSeverity::Debug;
273#endif
274 armnn::ConfigureLogging(true, true, level);
275 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
276
277 try
278 {
279 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
280 if (!testCaseProvider)
281 {
282 return 1;
283 }
284
285 InferenceTestOptions inferenceTestOptions;
286 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
287 {
288 return 1;
289 }
290
291 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
292 return success ? 0 : 1;
293 }
294 catch (armnn::Exception const& e)
295 {
296 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
297 return 1;
298 }
299}
300
telsoa01c577f2c2018-08-31 09:22:23 +0100301//
302// This function allows us to create a classifier inference test based on:
303// - a model file name
304// - which can be a binary or a text file for protobuf formats
305// - an input tensor name
306// - an output tensor name
307// - a set of test case ids
308// - a callback method which creates an object that can return images
309// called 'Database' in these tests
310// - and an input tensor shape
311//
telsoa014fcda012018-03-09 14:13:49 +0000312template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100313 typename TParser,
314 typename TConstructDatabaseCallable>
315int ClassifierInferenceTestMain(int argc,
316 char* argv[],
317 const char* modelFilename,
318 bool isModelBinary,
319 const char* inputBindingName,
320 const char* outputBindingName,
321 const std::vector<unsigned int>& defaultTestCaseIds,
322 TConstructDatabaseCallable constructDatabase,
323 const armnn::TensorShape* inputTensorShape)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000324
telsoa014fcda012018-03-09 14:13:49 +0000325{
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000326 BOOST_ASSERT(modelFilename);
327 BOOST_ASSERT(inputBindingName);
328 BOOST_ASSERT(outputBindingName);
329
telsoa014fcda012018-03-09 14:13:49 +0000330 return InferenceTestMain(argc, argv, defaultTestCaseIds,
331 [=]
332 ()
333 {
telsoa01c577f2c2018-08-31 09:22:23 +0100334 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000335 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
336
337 return make_unique<TestCaseProvider>(constructDatabase,
338 [&]
339 (typename InferenceModel::CommandLineOptions modelOptions)
340 {
341 if (!ValidateDirectory(modelOptions.m_ModelDir))
342 {
343 return std::unique_ptr<InferenceModel>();
344 }
345
346 typename InferenceModel::Params modelParams;
347 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000348 modelParams.m_InputBindings = { inputBindingName };
349 modelParams.m_OutputBindings = { outputBindingName };
350
351 if (inputTensorShape)
352 {
353 modelParams.m_InputShapes.push_back(*inputTensorShape);
354 }
355
telsoa014fcda012018-03-09 14:13:49 +0000356 modelParams.m_IsModelBinary = isModelBinary;
Aron Virginas-Tar339bcae2019-01-31 16:44:26 +0000357 modelParams.m_ComputeDevices = modelOptions.GetComputeDevicesAsBackendIds();
surmeh013537c2c2018-05-18 16:31:43 +0100358 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100359 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000360
361 return std::make_unique<InferenceModel>(modelParams);
362 });
363 });
364}
365
366} // namespace test
367} // namespace armnn