blob: 64f97c1f871ddba3de3cd225d22a21a94b0e154f [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#include "InferenceTest.hpp"
6
7#include "InferenceModel.hpp"
8
9#include <boost/algorithm/string.hpp>
10#include <boost/numeric/conversion/cast.hpp>
11#include <boost/log/trivial.hpp>
12#include <boost/filesystem/path.hpp>
13#include <boost/assert.hpp>
14#include <boost/format.hpp>
15#include <boost/program_options.hpp>
16#include <boost/filesystem/operations.hpp>
17
18#include <fstream>
19#include <iostream>
20#include <iomanip>
21#include <array>
22#include <chrono>
23
24using namespace std;
25using namespace std::chrono;
26using namespace armnn::test;
27
28namespace armnn
29{
30namespace test
31{
32
template <typename TTestCaseDatabase, typename TModel>
ClassifierTestCase<TTestCaseDatabase, TModel>::ClassifierTestCase(
    int& numInferencesRef,              // shared counter of inferences run; incremented in ProcessResult
    int& numCorrectInferencesRef,       // shared counter of correct classifications; incremented in ProcessResult
    const std::vector<unsigned int>& validationPredictions,   // expected predictions read from --validation-file-in (may be empty)
    std::vector<unsigned int>* validationPredictionsOut,      // sink for predictions when --validation-file-out is set; may be null
    TModel& model,
    unsigned int testCaseId,
    unsigned int label,                 // ground-truth class index for this test case
    std::vector<typename TModel::DataType> modelInput)
    // Base class runs the inference itself; output buffer sized from the model.
    : InferenceModelTestCase<TModel>(model, testCaseId, std::move(modelInput), model.GetOutputSize())
    , m_Label(label)
    , m_NumInferencesRef(numInferencesRef)
    , m_NumCorrectInferencesRef(numCorrectInferencesRef)
    , m_ValidationPredictions(validationPredictions)
    , m_ValidationPredictionsOut(validationPredictionsOut)
{
}
51
52template <typename TTestCaseDatabase, typename TModel>
53TestCaseResult ClassifierTestCase<TTestCaseDatabase, TModel>::ProcessResult(const InferenceTestOptions& params)
54{
55 auto& output = this->GetOutput();
56 const auto testCaseId = this->GetTestCaseId();
57
58 const unsigned int prediction = boost::numeric_cast<unsigned int>(
59 std::distance(output.begin(), std::max_element(output.begin(), output.end())));
60
61 // If we're just running the defaultTestCaseIds, each one must be classified correctly
62 if (params.m_IterationCount == 0 && prediction != m_Label)
63 {
64 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
65 " is incorrect (should be " << m_Label << ")";
66 return TestCaseResult::Failed;
67 }
68
69 // If a validation file was provided as input, check that the prediction matches
70 if (!m_ValidationPredictions.empty() && prediction != m_ValidationPredictions[testCaseId])
71 {
72 BOOST_LOG_TRIVIAL(error) << "Prediction for test case " << testCaseId << " (" << prediction << ")" <<
73 " doesn't match the prediction in the validation file (" << m_ValidationPredictions[testCaseId] << ")";
74 return TestCaseResult::Failed;
75 }
76
77 // If a validation file was requested as output, store the predictions
78 if (m_ValidationPredictionsOut)
79 {
80 m_ValidationPredictionsOut->push_back(prediction);
81 }
82
83 // Update accuracy stats
84 m_NumInferencesRef++;
85 if (prediction == m_Label)
86 {
87 m_NumCorrectInferencesRef++;
88 }
89
90 return TestCaseResult::Ok;
91}
92
template <typename TDatabase, typename InferenceModel>
template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
ClassifierTestCaseProvider<TDatabase, InferenceModel>::ClassifierTestCaseProvider(
    TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel)
    // Factories are stored and invoked later, in ProcessCommandLineOptions,
    // once the command-line options they depend on have been parsed.
    : m_ConstructModel(constructModel)
    , m_ConstructDatabase(constructDatabase)
    , m_NumInferences(0)
    , m_NumCorrectInferences(0)
{
}
103
104template <typename TDatabase, typename InferenceModel>
105void ClassifierTestCaseProvider<TDatabase, InferenceModel>::AddCommandLineOptions(
106 boost::program_options::options_description& options)
107{
108 namespace po = boost::program_options;
109
110 options.add_options()
111 ("validation-file-in", po::value<std::string>(&m_ValidationFileIn)->default_value(""),
112 "Reads expected predictions from the given file and confirms they match the actual predictions.")
113 ("validation-file-out", po::value<std::string>(&m_ValidationFileOut)->default_value(""),
114 "Predictions are saved to the given file for later use via --validation-file-in.")
115 ("data-dir,d", po::value<std::string>(&m_DataDir)->required(),
116 "Path to directory containing test data");
117
118 InferenceModel::AddCommandLineOptions(options, m_ModelCommandLineOptions);
119}
120
121template <typename TDatabase, typename InferenceModel>
122bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::ProcessCommandLineOptions()
123{
124 if (!ValidateDirectory(m_DataDir))
125 {
126 return false;
127 }
128
129 ReadPredictions();
130
131 m_Model = m_ConstructModel(m_ModelCommandLineOptions);
132 if (!m_Model)
133 {
134 return false;
135 }
136
137 m_Database = std::make_unique<TDatabase>(m_ConstructDatabase(m_DataDir.c_str()));
138 if (!m_Database)
139 {
140 return false;
141 }
142
143 return true;
144}
145
146template <typename TDatabase, typename InferenceModel>
147std::unique_ptr<IInferenceTestCase>
148ClassifierTestCaseProvider<TDatabase, InferenceModel>::GetTestCase(unsigned int testCaseId)
149{
150 std::unique_ptr<typename TDatabase::TTestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
151 if (testCaseData == nullptr)
152 {
153 return nullptr;
154 }
155
156 return std::make_unique<ClassifierTestCase<TDatabase, InferenceModel>>(
157 m_NumInferences,
158 m_NumCorrectInferences,
159 m_ValidationPredictions,
160 m_ValidationFileOut.empty() ? nullptr : &m_ValidationPredictionsOut,
161 *m_Model,
162 testCaseId,
163 testCaseData->m_Label,
164 std::move(testCaseData->m_InputImage));
165}
166
167template <typename TDatabase, typename InferenceModel>
168bool ClassifierTestCaseProvider<TDatabase, InferenceModel>::OnInferenceTestFinished()
169{
170 const double accuracy = boost::numeric_cast<double>(m_NumCorrectInferences) /
171 boost::numeric_cast<double>(m_NumInferences);
172 BOOST_LOG_TRIVIAL(info) << std::fixed << std::setprecision(3) << "Overall accuracy: " << accuracy;
173
174 // If a validation file was requested as output, save the predictions to it
175 if (!m_ValidationFileOut.empty())
176 {
177 std::ofstream validationFileOut(m_ValidationFileOut.c_str(), std::ios_base::trunc | std::ios_base::out);
178 if (validationFileOut.good())
179 {
180 for (const unsigned int prediction : m_ValidationPredictionsOut)
181 {
182 validationFileOut << prediction << std::endl;
183 }
184 }
185 else
186 {
187 BOOST_LOG_TRIVIAL(error) << "Failed to open output validation file: " << m_ValidationFileOut;
188 return false;
189 }
190 }
191
192 return true;
193}
194
195template <typename TDatabase, typename InferenceModel>
196void ClassifierTestCaseProvider<TDatabase, InferenceModel>::ReadPredictions()
197{
198 // Read expected predictions from the input validation file (if provided)
199 if (!m_ValidationFileIn.empty())
200 {
201 std::ifstream validationFileIn(m_ValidationFileIn.c_str(), std::ios_base::in);
202 if (validationFileIn.good())
203 {
204 while (!validationFileIn.eof())
205 {
206 unsigned int i;
207 validationFileIn >> i;
208 m_ValidationPredictions.emplace_back(i);
209 }
210 }
211 else
212 {
213 throw armnn::Exception(boost::str(boost::format("Failed to open input validation file: %1%")
214 % m_ValidationFileIn));
215 }
216 }
217}
218
219template<typename TConstructTestCaseProvider>
220int InferenceTestMain(int argc,
221 char* argv[],
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructTestCaseProvider constructTestCaseProvider)
224{
225 // Configure logging for both the ARMNN library and this test program
226#ifdef NDEBUG
227 armnn::LogSeverity level = armnn::LogSeverity::Info;
228#else
229 armnn::LogSeverity level = armnn::LogSeverity::Debug;
230#endif
231 armnn::ConfigureLogging(true, true, level);
232 armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, level);
233
234 try
235 {
236 std::unique_ptr<IInferenceTestCaseProvider> testCaseProvider = constructTestCaseProvider();
237 if (!testCaseProvider)
238 {
239 return 1;
240 }
241
242 InferenceTestOptions inferenceTestOptions;
243 if (!ParseCommandLine(argc, argv, *testCaseProvider, inferenceTestOptions))
244 {
245 return 1;
246 }
247
248 const bool success = InferenceTest(inferenceTestOptions, defaultTestCaseIds, *testCaseProvider);
249 return success ? 0 : 1;
250 }
251 catch (armnn::Exception const& e)
252 {
253 BOOST_LOG_TRIVIAL(fatal) << "Armnn Error: " << e.what();
254 return 1;
255 }
256}
257
258template<typename TDatabase,
259 typename TParser,
260 typename TConstructDatabaseCallable>
261int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
262 const char* inputBindingName, const char* outputBindingName,
263 const std::vector<unsigned int>& defaultTestCaseIds,
264 TConstructDatabaseCallable constructDatabase,
265 const armnn::TensorShape* inputTensorShape)
266{
267 return InferenceTestMain(argc, argv, defaultTestCaseIds,
268 [=]
269 ()
270 {
271 using InferenceModel = InferenceModel<TParser, float>;
272 using TestCaseProvider = ClassifierTestCaseProvider<TDatabase, InferenceModel>;
273
274 return make_unique<TestCaseProvider>(constructDatabase,
275 [&]
276 (typename InferenceModel::CommandLineOptions modelOptions)
277 {
278 if (!ValidateDirectory(modelOptions.m_ModelDir))
279 {
280 return std::unique_ptr<InferenceModel>();
281 }
282
283 typename InferenceModel::Params modelParams;
284 modelParams.m_ModelPath = modelOptions.m_ModelDir + modelFilename;
285 modelParams.m_InputBinding = inputBindingName;
286 modelParams.m_OutputBinding = outputBindingName;
287 modelParams.m_InputTensorShape = inputTensorShape;
288 modelParams.m_IsModelBinary = isModelBinary;
289 modelParams.m_ComputeDevice = modelOptions.m_ComputeDevice;
290
291 return std::make_unique<InferenceModel>(modelParams);
292 });
293 });
294}
295
296} // namespace test
297} // namespace armnn