blob: 181afe4d8fd25945a30c93b8706fb3da72579e18 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// See LICENSE file in the project root for full license information.
4//
5#pragma once
6
7#include "armnn/ArmNN.hpp"
8#include "armnn/TypesUtils.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +01009#include "InferenceModel.hpp"
10
telsoa014fcda012018-03-09 14:13:49 +000011#include <Logging.hpp>
12
13#include <boost/log/core/core.hpp>
14#include <boost/program_options.hpp>
15
telsoa01c577f2c2018-08-31 09:22:23 +010016
telsoa014fcda012018-03-09 14:13:49 +000017namespace armnn
18{
19
20inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21{
22 std::string token;
23 in >> token;
24 compute = armnn::ParseComputeDevice(token.c_str());
25 if (compute == armnn::Compute::Undefined)
26 {
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29 }
30 return in;
31}
32
33namespace test
34{
35
36class TestFrameworkException : public Exception
37{
38public:
39 using Exception::Exception;
40};
41
/// Options for an inference test run. ParseCommandLine fills an instance of this
/// struct through its out-parameter.
struct InferenceTestOptions
{
    /// Iteration count; 0 by default. Presumably the number of times the test
    /// cases are run; confirm against InferenceTest.
    unsigned int m_IterationCount;
    /// Presumably the path of a file that receives per-inference timings; empty by default.
    std::string m_InferenceTimesFile;
    /// Whether profiling is enabled; false by default.
    bool m_EnableProfiling;

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_EnableProfiling(false) // bool literal rather than the integer 0
    {}
};
53
/// Outcome reported for a single test case.
enum class TestCaseResult
{
    Ok,     ///< The test completed without any errors.
    Failed, ///< The test failed (e.g. the prediction didn't match the validation file).
            ///< This will eventually fail the whole program but the remaining test cases will still be run.
    Abort   ///< The test failed with a fatal error. The remaining tests will not be run.
};
64
65class IInferenceTestCase
66{
67public:
68 virtual ~IInferenceTestCase() {}
69
70 virtual void Run() = 0;
71 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
72};
73
74class IInferenceTestCaseProvider
75{
76public:
77 virtual ~IInferenceTestCaseProvider() {}
78
79 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
80 virtual bool ProcessCommandLineOptions() { return true; };
81 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
82 virtual bool OnInferenceTestFinished() { return true; };
83};
84
85template <typename TModel>
86class InferenceModelTestCase : public IInferenceTestCase
87{
88public:
89 InferenceModelTestCase(TModel& model,
90 unsigned int testCaseId,
91 std::vector<typename TModel::DataType> modelInput,
92 unsigned int outputSize)
93 : m_Model(model)
94 , m_TestCaseId(testCaseId)
95 , m_Input(std::move(modelInput))
96 {
97 m_Output.resize(outputSize);
98 }
99
100 virtual void Run() override
101 {
102 m_Model.Run(m_Input, m_Output);
103 }
104
105protected:
106 unsigned int GetTestCaseId() const { return m_TestCaseId; }
107 const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
108
109private:
110 TModel& m_Model;
111 unsigned int m_TestCaseId;
112 std::vector<typename TModel::DataType> m_Input;
113 std::vector<typename TModel::DataType> m_Output;
114};
115
telsoa01c577f2c2018-08-31 09:22:23 +0100116template <typename TDataType>
117struct ToFloat { }; // nothing defined for the generic case
118
119template <>
120struct ToFloat<float>
121{
122 static inline float Convert(float value, const InferenceModelInternal::QuantizationParams &)
123 {
124 // assuming that float models are not quantized
125 return value;
126 }
127};
128
129template <>
130struct ToFloat<uint8_t>
131{
132 static inline float Convert(uint8_t value,
133 const InferenceModelInternal::QuantizationParams & quantizationParams)
134 {
135 return armnn::Dequantize<uint8_t>(value,
136 quantizationParams.first,
137 quantizationParams.second);
138 }
139};
140
telsoa014fcda012018-03-09 14:13:49 +0000141template <typename TTestCaseDatabase, typename TModel>
142class ClassifierTestCase : public InferenceModelTestCase<TModel>
143{
144public:
145 ClassifierTestCase(int& numInferencesRef,
146 int& numCorrectInferencesRef,
147 const std::vector<unsigned int>& validationPredictions,
148 std::vector<unsigned int>* validationPredictionsOut,
149 TModel& model,
150 unsigned int testCaseId,
151 unsigned int label,
152 std::vector<typename TModel::DataType> modelInput);
153
154 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
155
156private:
157 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100158 InferenceModelInternal::QuantizationParams m_QuantizationParams;
159
telsoa014fcda012018-03-09 14:13:49 +0000160 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
161 /// @{
162 int& m_NumInferencesRef;
163 int& m_NumCorrectInferencesRef;
164 const std::vector<unsigned int>& m_ValidationPredictions;
165 std::vector<unsigned int>* m_ValidationPredictionsOut;
166 /// @}
167};
168
169template <typename TDatabase, typename InferenceModel>
170class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
171{
172public:
173 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
174 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
175
176 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
177 virtual bool ProcessCommandLineOptions() override;
178 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
179 virtual bool OnInferenceTestFinished() override;
180
181private:
182 void ReadPredictions();
183
184 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
185 std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
186 std::unique_ptr<InferenceModel> m_Model;
187
188 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100189 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000190 std::unique_ptr<TDatabase> m_Database;
191
telsoa01c577f2c2018-08-31 09:22:23 +0100192 int m_NumInferences; // Referenced by test cases.
193 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000194
195 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100196 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000197
198 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100199 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000200};
201
202bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
203 InferenceTestOptions& outParams);
204
205bool ValidateDirectory(std::string& dir);
206
207bool InferenceTest(const InferenceTestOptions& params,
208 const std::vector<unsigned int>& defaultTestCaseIds,
209 IInferenceTestCaseProvider& testCaseProvider);
210
211template<typename TConstructTestCaseProvider>
212int InferenceTestMain(int argc,
213 char* argv[],
214 const std::vector<unsigned int>& defaultTestCaseIds,
215 TConstructTestCaseProvider constructTestCaseProvider);
216
217template<typename TDatabase,
218 typename TParser,
219 typename TConstructDatabaseCallable>
220int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
221 const char* inputBindingName, const char* outputBindingName,
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructDatabaseCallable constructDatabase,
224 const armnn::TensorShape* inputTensorShape = nullptr);
225
226} // namespace test
227} // namespace armnn
228
229#include "InferenceTest.inl"