blob: 91a65ea494625326a2d056a880e8170046994953 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
David Beckf0b48452018-10-19 15:20:56 +01007#include <armnn/ArmNN.hpp>
8#include <armnn/TypesUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +01009#include "InferenceModel.hpp"
10
telsoa014fcda012018-03-09 14:13:49 +000011#include <Logging.hpp>
12
13#include <boost/log/core/core.hpp>
14#include <boost/program_options.hpp>
15
telsoa01c577f2c2018-08-31 09:22:23 +010016
telsoa014fcda012018-03-09 14:13:49 +000017namespace armnn
18{
19
20inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21{
22 std::string token;
23 in >> token;
24 compute = armnn::ParseComputeDevice(token.c_str());
25 if (compute == armnn::Compute::Undefined)
26 {
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29 }
30 return in;
31}
32
David Beckf0b48452018-10-19 15:20:56 +010033inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
34{
35 std::string token;
36 in >> token;
37 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38 if (compute == armnn::Compute::Undefined)
39 {
40 in.setstate(std::ios_base::failbit);
41 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
42 }
43 backend = compute;
44 return in;
45}
46
telsoa014fcda012018-03-09 14:13:49 +000047namespace test
48{
49
50class TestFrameworkException : public Exception
51{
52public:
53 using Exception::Exception;
54};
55
/// Options shared by all inference tests. An instance is filled in by ParseCommandLine (outParams).
struct InferenceTestOptions
{
    /// Iteration count. Defaults to 0; how 0 is interpreted is up to the test driver (confirm in InferenceTest).
    unsigned int m_IterationCount;
    /// Presumably a path to write inference timings to. Empty by default.
    std::string m_InferenceTimesFile;
    /// Whether profiling is requested. Off by default.
    bool m_EnableProfiling;

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false) // a bool: initialise with false, not the integer literal 0
    {}
};
67
/// Outcome reported by IInferenceTestCase::ProcessResult for a single test case.
enum class TestCaseResult
{
    Ok = 0,     ///< The test completed without any errors.
    Failed = 1, ///< The test failed (e.g. the prediction didn't match the validation file).
                ///< This will eventually fail the whole program, but the remaining test cases are still run.
    Abort = 2   ///< The test failed with a fatal error. The remaining tests will not be run.
};
78
79class IInferenceTestCase
80{
81public:
82 virtual ~IInferenceTestCase() {}
83
84 virtual void Run() = 0;
85 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
86};
87
88class IInferenceTestCaseProvider
89{
90public:
91 virtual ~IInferenceTestCaseProvider() {}
92
93 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
94 virtual bool ProcessCommandLineOptions() { return true; };
95 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
96 virtual bool OnInferenceTestFinished() { return true; };
97};
98
99template <typename TModel>
100class InferenceModelTestCase : public IInferenceTestCase
101{
102public:
Ferran Balaguerc602f292019-02-08 17:09:55 +0000103 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000104
telsoa014fcda012018-03-09 14:13:49 +0000105 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000106 unsigned int testCaseId,
107 const std::vector<TContainer>& inputs,
108 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000109 : m_Model(model)
110 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000111 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000112 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000113 // Initialize output vector
114 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000115 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116
117 for (size_t i = 0; i < numOutputs; i++)
118 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000119 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000120 }
telsoa014fcda012018-03-09 14:13:49 +0000121 }
122
123 virtual void Run() override
124 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000125 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000126 }
127
128protected:
129 unsigned int GetTestCaseId() const { return m_TestCaseId; }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000130 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000131
132private:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000133 TModel& m_Model;
134 unsigned int m_TestCaseId;
135 std::vector<TContainer> m_Inputs;
136 std::vector<TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000137};
138
/// Converts model output values of type @p TDataType to float.
/// The generic case is deliberately empty: only explicit specialisations provide Convert().
template <class TDataType>
struct ToFloat {};
141
142template <>
143struct ToFloat<float>
144{
145 static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
146 {
147 // assuming that float models are not quantized
148 return value;
149 }
Ferran Balaguerc602f292019-02-08 17:09:55 +0000150
151 static inline float Convert(int value, const InferenceModelInternal::QuantizationParams &)
152 {
153 // assuming that float models are not quantized
154 return static_cast<float>(value);
155 }
telsoa01c577f2c2018-08-31 09:22:23 +0100156};
157
158template <>
159struct ToFloat<uint8_t>
160{
161 static inline float Convert(uint8_t value,
162 const InferenceModelInternal::QuantizationParams & quantizationParams)
163 {
164 return armnn::Dequantize<uint8_t>(value,
165 quantizationParams.first,
166 quantizationParams.second);
167 }
Ferran Balaguerc602f292019-02-08 17:09:55 +0000168
169 static inline float Convert(int value,
170 const InferenceModelInternal::QuantizationParams & quantizationParams)
171 {
172 return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
173 quantizationParams.first,
174 quantizationParams.second);
175 }
176
177 static inline float Convert(float value,
178 const InferenceModelInternal::QuantizationParams & quantizationParams)
179 {
180 return armnn::Dequantize<uint8_t>(static_cast<uint8_t>(value),
181 quantizationParams.first,
182 quantizationParams.second);
183 }
telsoa01c577f2c2018-08-31 09:22:23 +0100184};
185
telsoa014fcda012018-03-09 14:13:49 +0000186template <typename TTestCaseDatabase, typename TModel>
187class ClassifierTestCase : public InferenceModelTestCase<TModel>
188{
189public:
190 ClassifierTestCase(int& numInferencesRef,
191 int& numCorrectInferencesRef,
192 const std::vector<unsigned int>& validationPredictions,
193 std::vector<unsigned int>* validationPredictionsOut,
194 TModel& model,
195 unsigned int testCaseId,
196 unsigned int label,
197 std::vector<typename TModel::DataType> modelInput);
198
199 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
200
201private:
202 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100203 InferenceModelInternal::QuantizationParams m_QuantizationParams;
204
telsoa014fcda012018-03-09 14:13:49 +0000205 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
206 /// @{
207 int& m_NumInferencesRef;
208 int& m_NumCorrectInferencesRef;
209 const std::vector<unsigned int>& m_ValidationPredictions;
210 std::vector<unsigned int>* m_ValidationPredictionsOut;
211 /// @}
212};
213
214template <typename TDatabase, typename InferenceModel>
215class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
216{
217public:
218 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
219 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
220
221 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
222 virtual bool ProcessCommandLineOptions() override;
223 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
224 virtual bool OnInferenceTestFinished() override;
225
226private:
227 void ReadPredictions();
228
229 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
230 std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
231 std::unique_ptr<InferenceModel> m_Model;
232
233 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100234 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000235 std::unique_ptr<TDatabase> m_Database;
236
telsoa01c577f2c2018-08-31 09:22:23 +0100237 int m_NumInferences; // Referenced by test cases.
238 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000239
240 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100241 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000242
243 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100244 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000245};
246
247bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
248 InferenceTestOptions& outParams);
249
250bool ValidateDirectory(std::string& dir);
251
252bool InferenceTest(const InferenceTestOptions& params,
253 const std::vector<unsigned int>& defaultTestCaseIds,
254 IInferenceTestCaseProvider& testCaseProvider);
255
256template<typename TConstructTestCaseProvider>
257int InferenceTestMain(int argc,
258 char* argv[],
259 const std::vector<unsigned int>& defaultTestCaseIds,
260 TConstructTestCaseProvider constructTestCaseProvider);
261
262template<typename TDatabase,
263 typename TParser,
264 typename TConstructDatabaseCallable>
265int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
266 const char* inputBindingName, const char* outputBindingName,
267 const std::vector<unsigned int>& defaultTestCaseIds,
268 TConstructDatabaseCallable constructDatabase,
269 const armnn::TensorShape* inputTensorShape = nullptr);
270
271} // namespace test
272} // namespace armnn
273
274#include "InferenceTest.inl"