blob: 5e858f06d3e138c4d65fd9afb55fe1fb4bcdfb4d [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#include "InferenceTest.hpp"
6
telsoa014fcda012018-03-09 14:13:49 +00007#include <boost/algorithm/string.hpp>
8#include <boost/numeric/conversion/cast.hpp>
9#include <boost/log/trivial.hpp>
10#include <boost/filesystem/path.hpp>
11#include <boost/assert.hpp>
12#include <boost/format.hpp>
13#include <boost/program_options.hpp>
14#include <boost/filesystem/operations.hpp>
15
#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <exception>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
21
22using namespace std;
23using namespace std::chrono;
24using namespace armnn::test;
25
26namespace armnn
27{
28namespace test
29{
30
telsoa01c577f2c2018-08-31 09:22:23 +010031
telsoa014fcda012018-03-09 14:13:49 +000032template <typename TTestCaseDatabase, typename TModel>
33ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
34 int& numInferencesRef,
35 int& numCorrectInferencesRef,
36 const std::vector<unsigned int>& validationPredictions,
37 std::vector<unsigned int>* validationPredictionsOut,
38 TModel& model,
39 unsigned int testCaseId,
40 unsigned int label,
41 std::vector<typename TModel::DataType> modelInput)
42 : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
43 , m_Label(label)
telsoa01c577f2c2018-08-31 09:22:23 +010044 , m_QuantizationParams(model.GetQuantizationParams())
telsoa014fcda012018-03-09 14:13:49 +000045 , m_NumInferencesRef(numInferencesRef)
46 , m_NumCorrectInferencesRef(numCorrectInferencesRef)
47 , m_ValidationPredictions(validationPredictions)
48 , m_ValidationPredictionsOut(validationPredictionsOut)
49{
50}
51
52template <typename TTestCaseDatabase, typename TModel>
53TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
54{
55 auto& output = this->GetOutput();
56 const auto testCaseId = this->GetTestCaseId();
57
surmeh01bceff2f2018-03-29 16:29:27 +010058 std::map<float,int> resultMap;
59 {
60 int index = 0;
61 for (const auto & o : output)
62 {
telsoa01c577f2c2018-08-31 09:22:23 +010063 resultMap[ToFloat<typename TModel::DataType>::Convert(o, m_QuantizationParams)] = index++;
surmeh01bceff2f2018-03-29 16:29:27 +010064 }
65 }
66
67 {
68 BOOST_LOG_TRIVIAL(info) << "= Prediction values for test #" << testCaseId;
69 auto it = resultMap.rbegin();
70 for (int i=0; i<5 && it != resultMap.rend(); ++i)
71 {
72 BOOST_LOG_TRIVIAL(info) << "Top(" << (i+1) << ") prediction is " << it->second <<
73 " with confidence: " << 100.0*(it->first) << "%";
74 ++it;
75 }
76 }
77
telsoa014fcda012018-03-09 14:13:49 +000078 const unsigned int prediction = boost::numeric_cast<unsigned int>(
79 std::distance(output.begin(), std::max_element(output.begin(), output.end())));
80
telsoa01c577f2c2018-08-31 09:22:23 +010081 // If we're just running the defaultTestCaseIds, each one must be classified correctly.
telsoa014fcda012018-03-09 14:13:49 +000082 if (params.m_IterationCount == 0 && prediction != m_Label)
83 {
84 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
85 " is incorrect (should be " << m_Label << ")";
86 return TestCaseResult::Failed;
87 }
88
telsoa01c577f2c2018-08-31 09:22:23 +010089 // If a validation file was provided as input, it checks that the prediction matches.
telsoa014fcda012018-03-09 14:13:49 +000090 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
91 {
92 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
93 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
94 return TestCaseResult::Failed;
95 }
96
telsoa01c577f2c2018-08-31 09:22:23 +010097 // If a validation file was requested as output, it stores the predictions.
telsoa014fcda012018-03-09 14:13:49 +000098 if (m_ValidationPredictionsOut)
99 {
100 m_ValidationPredictionsOut->push_back(prediction);
101 }
102
telsoa01c577f2c2018-08-31 09:22:23 +0100103 // Updates accuracy stats.
telsoa014fcda012018-03-09 14:13:49 +0000104 m_NumInferencesRef++;
105 if (prediction == m_Label)
106 {
107 m_NumCorrectInferencesRef++;
108 }
109
110 return TestCaseResult::Ok;
111}
112
113template <typename TDatabase, typename InferenceModel>
114template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
115ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
116 TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
117 : m_ConstructModel(constructModel)
118 , m_ConstructDatabase(constructDatabase)
119 , m_NumInferences(0)
120 , m_NumCorrectInferences(0)
121{
122}
123
124template <typename TDatabase, typename InferenceModel>
125void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
126 boost::program_options::options_description& options)
127{
128 namespace po = boost::program_options;
129
130 options.add_options()
131 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
132 "Reads expected predictions from the given file and confirms they match the actual predictions.")
133 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
134 "Predictions are saved to the given file for later use via --validation-file-in.")
135 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
136 "Path to directory containing test data");
137
138 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
139}
140
141template <typename TDatabase, typename InferenceModel>
142bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
143{
144 if (!ValidateDirectory(m_DataDir))
145 {
146 return false;
147 }
148
149 ReadPredictions();
150
151 m_Model = m_ConstructModel(m_ModelCommandLineOptions);
152 if (!m_Model)
153 {
154 return false;
155 }
156
telsoa01c577f2c2018-08-31 09:22:23 +0100157 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str(), *m_Model));
telsoa014fcda012018-03-09 14:13:49 +0000158 if (!m_Database)
159 {
160 return false;
161 }
162
163 return true;
164}
165
166template <typename TDatabase, typename InferenceModel>
167std::unique_ptr<IInferenceTestCase>
168ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
169{
170 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
171 if (testCaseData == nullptr)
172 {
173 return nullptr;
174 }
175
176 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
177 m_NumInferences,
178 m_NumCorrectInferences,
179 m_ValidationPredictions,
180 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
181 *m_Model,
182 testCaseId,
183 testCaseData->m_Label,
184 std::move(testCaseData->m_InputImage));
185}
186
187template <typename TDatabase, typename InferenceModel>
188bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
189{
190 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
191 boost::numeric_cast<double>(m_NumInferences);
192 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
193
telsoa01c577f2c2018-08-31 09:22:23 +0100194 // If a validation file was requested as output, the predictions are saved to it.
telsoa014fcda012018-03-09 14:13:49 +0000195 if (!m_ValidationFileOut.empty())
196 {
197 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
198 if (validationFileOut.good())
199 {
200 for (const unsigned int prediction : m_ValidationPredictionsOut)
201 {
202 validationFileOut << prediction << std::endl;
203 }
204 }
205 else
206 {
207 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
208 return false;
209 }
210 }
211
212 return true;
213}
214
215template <typename TDatabase, typename InferenceModel>
216void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
217{
telsoa01c577f2c2018-08-31 09:22:23 +0100218 // Reads the expected predictions from the input validation file (if provided).
telsoa014fcda012018-03-09 14:13:49 +0000219 if (!m_ValidationFileIn.empty())
220 {
221 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
222 if (validationFileIn.good())
223 {
224 while (!validationFileIn.eof())
225 {
226 unsigned int i;
227 validationFileIn >> i;
228 m_ValidationPredictions.emplace_back(i);
229 }
230 }
231 else
232 {
233 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
234 % m_ValidationFileIn));
235 }
236 }
237}
238
239template<typename TConstructTestCaseProvider>
240int InferenceTestMain(int argc,
241 char* argv[],
242 const std::vector<unsigned int>& defaultTestCaseIds,
243 TConstructTestCaseProvider constructTestCaseProvider)
244{
telsoa01c577f2c2018-08-31 09:22:23 +0100245 // Configures logging for both the ARMNN library and this test program.
telsoa014fcda012018-03-09 14:13:49 +0000246#ifdef NDEBUG
247 armnn::LogSeverity level = armnn::LogSeverity::Info;
248#else
249 armnn::LogSeverity level = armnn::LogSeverity::Debug;
250#endif
251 armnn::ConfigureLogging(true, true, level);
252 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
253
254 try
255 {
256 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
257 if (!testCaseProvider)
258 {
259 return 1;
260 }
261
262 InferenceTestOptions inferenceTestOptions;
263 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
264 {
265 return 1;
266 }
267
268 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
269 return success ? 0 : 1;
270 }
271 catch (armnn::Exception const& e)
272 {
273 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
274 return 1;
275 }
276}
277
telsoa01c577f2c2018-08-31 09:22:23 +0100278//
279// This function allows us to create a classifier inference test based on:
280// - a model file name
281// - which can be a binary or a text file for protobuf formats
282// - an input tensor name
283// - an output tensor name
284// - a set of test case ids
285// - a callback method which creates an object that can return images
286// called 'Database' in these tests
287// - and an input tensor shape
288//
telsoa014fcda012018-03-09 14:13:49 +0000289template<typename TDatabase,
telsoa01c577f2c2018-08-31 09:22:23 +0100290 typename TParser,
291 typename TConstructDatabaseCallable>
292int ClassifierInferenceTestMain(int argc,
293 char* argv[],
294 const char* modelFilename,
295 bool isModelBinary,
296 const char* inputBindingName,
297 const char* outputBindingName,
298 const std::vector<unsigned int>& defaultTestCaseIds,
299 TConstructDatabaseCallable constructDatabase,
300 const armnn::TensorShape* inputTensorShape)
telsoa014fcda012018-03-09 14:13:49 +0000301{
302 return InferenceTestMain(argc, argv, defaultTestCaseIds,
303 [=]
304 ()
305 {
telsoa01c577f2c2018-08-31 09:22:23 +0100306 using InferenceModel = InferenceModel<TParser, typename TDatabase::DataType>;
telsoa014fcda012018-03-09 14:13:49 +0000307 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
308
309 return make_unique<TestCaseProvider>(constructDatabase,
310 [&]
311 (typename InferenceModel::CommandLineOptions modelOptions)
312 {
313 if (!ValidateDirectory(modelOptions.m_ModelDir))
314 {
315 return std::unique_ptr<InferenceModel>();
316 }
317
318 typename InferenceModel::Params modelParams;
319 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
320 modelParams.m_InputBinding = inputBindingName;
321 modelParams.m_OutputBinding = outputBindingName;
322 modelParams.m_InputTensorShape = inputTensorShape;
323 modelParams.m_IsModelBinary = isModelBinary;
324 modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
surmeh013537c2c2018-05-18 16:31:43 +0100325 modelParams.m_VisualizePostOptimizationModel = modelOptions.m_VisualizePostOptimizationModel;
telsoa01c577f2c2018-08-31 09:22:23 +0100326 modelParams.m_EnableFp16TurboMode = modelOptions.m_EnableFp16TurboMode;
telsoa014fcda012018-03-09 14:13:49 +0000327
328 return std::make_unique<InferenceModel>(modelParams);
329 });
330 });
331}
332
333} // namespace test
334} // namespace armnn