blob: 6423d1c7ff218d608e5056f31204872475064f72 [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
David Beckf0b48452018-10-19 15:20:56 +01007#include <armnn/ArmNN.hpp>
Derek Lamberti08446972019-11-26 16:38:31 +00008#include <armnn/Logging.hpp>
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnn/TypesUtils.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010010#include "InferenceModel.hpp"
11
Derek Lambertieb1fce02019-12-10 21:20:10 +000012#include <boost/core/ignore_unused.hpp>
telsoa014fcda012018-03-09 14:13:49 +000013#include <boost/program_options.hpp>
14
telsoa01c577f2c2018-08-31 09:22:23 +010015
telsoa014fcda012018-03-09 14:13:49 +000016namespace armnn
17{
18
19inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
20{
21 std::string token;
22 in >> token;
23 compute = armnn::ParseComputeDevice(token.c_str());
24 if (compute == armnn::Compute::Undefined)
25 {
26 in.setstate(std::ios_base::failbit);
27 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
28 }
29 return in;
30}
31
David Beckf0b48452018-10-19 15:20:56 +010032inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
33{
34 std::string token;
35 in >> token;
36 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
37 if (compute == armnn::Compute::Undefined)
38 {
39 in.setstate(std::ios_base::failbit);
40 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
41 }
42 backend = compute;
43 return in;
44}
45
telsoa014fcda012018-03-09 14:13:49 +000046namespace test
47{
48
49class TestFrameworkException : public Exception
50{
51public:
52 using Exception::Exception;
53};
54
/// Options common to every inference test. Filled in by ParseCommandLine() and passed to
/// InferenceTest() and to the test-case providers.
///
/// Defaults are given as member initialisers. The previous hand-written constructor initialised
/// a bool from the integer literal 0 and listed only some of the fields.
struct InferenceTestOptions
{
    unsigned int m_IterationCount = 0;  // presumably the number of test cases/iterations to run - TODO confirm
    std::string  m_InferenceTimesFile;  // presumably a path for recording inference timings - TODO confirm
    bool         m_EnableProfiling = false;
    std::string  m_DynamicBackendsPath; // empty by default
};
68
/// Outcome reported by IInferenceTestCase::ProcessResult() for one test case.
enum class TestCaseResult : int
{
    Ok,     ///< The test completed without any errors.
    Failed, ///< The test failed (e.g. the prediction didn't match the validation file). The whole
            ///< program will eventually fail, but the remaining test cases are still run.
    Abort   ///< The test failed with a fatal error; the remaining tests are not run.
};
79
80class IInferenceTestCase
81{
82public:
83 virtual ~IInferenceTestCase() {}
84
85 virtual void Run() = 0;
86 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
87};
88
89class IInferenceTestCaseProvider
90{
91public:
92 virtual ~IInferenceTestCaseProvider() {}
93
Derek Lambertieb1fce02019-12-10 21:20:10 +000094 virtual void AddCommandLineOptions(boost::program_options::options_description& options)
95 {
96 boost::ignore_unused(options);
97 };
98 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
99 {
100 boost::ignore_unused(commonOptions);
101 return true;
102 };
telsoa014fcda012018-03-09 14:13:49 +0000103 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
104 virtual bool OnInferenceTestFinished() { return true; };
105};
106
107template <typename TModel>
108class InferenceModelTestCase : public IInferenceTestCase
109{
110public:
Ferran Balaguerc602f292019-02-08 17:09:55 +0000111 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000112
telsoa014fcda012018-03-09 14:13:49 +0000113 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000114 unsigned int testCaseId,
115 const std::vector<TContainer>& inputs,
116 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000117 : m_Model(model)
118 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000119 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000120 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000121 // Initialize output vector
122 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000123 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000124
125 for (size_t i = 0; i < numOutputs; i++)
126 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000127 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000128 }
telsoa014fcda012018-03-09 14:13:49 +0000129 }
130
131 virtual void Run() override
132 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000133 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000134 }
135
136protected:
137 unsigned int GetTestCaseId() const { return m_TestCaseId; }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000138 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000139
140private:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000141 TModel& m_Model;
142 unsigned int m_TestCaseId;
143 std::vector<TContainer> m_Inputs;
144 std::vector<TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000145};
146
147template <typename TTestCaseDatabase, typename TModel>
148class ClassifierTestCase : public InferenceModelTestCase<TModel>
149{
150public:
151 ClassifierTestCase(int& numInferencesRef,
152 int& numCorrectInferencesRef,
153 const std::vector<unsigned int>& validationPredictions,
154 std::vector<unsigned int>* validationPredictionsOut,
155 TModel& model,
156 unsigned int testCaseId,
157 unsigned int label,
158 std::vector<typename TModel::DataType> modelInput);
159
160 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
161
162private:
163 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100164 InferenceModelInternal::QuantizationParams m_QuantizationParams;
165
telsoa014fcda012018-03-09 14:13:49 +0000166 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
167 /// @{
168 int& m_NumInferencesRef;
169 int& m_NumCorrectInferencesRef;
170 const std::vector<unsigned int>& m_ValidationPredictions;
171 std::vector<unsigned int>* m_ValidationPredictionsOut;
172 /// @}
173};
174
175template <typename TDatabase, typename InferenceModel>
176class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
177{
178public:
179 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
180 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
181
182 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100183 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
telsoa014fcda012018-03-09 14:13:49 +0000184 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
185 virtual bool OnInferenceTestFinished() override;
186
187private:
188 void ReadPredictions();
189
190 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100191 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
192 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000193 std::unique_ptr<InferenceModel> m_Model;
194
195 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100196 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000197 std::unique_ptr<TDatabase> m_Database;
198
telsoa01c577f2c2018-08-31 09:22:23 +0100199 int m_NumInferences; // Referenced by test cases.
200 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000201
202 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100203 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000204
205 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100206 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000207};
208
209bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
210 InferenceTestOptions& outParams);
211
212bool ValidateDirectory(std::string& dir);
213
214bool InferenceTest(const InferenceTestOptions& params,
215 const std::vector<unsigned int>& defaultTestCaseIds,
216 IInferenceTestCaseProvider& testCaseProvider);
217
218template<typename TConstructTestCaseProvider>
219int InferenceTestMain(int argc,
220 char* argv[],
221 const std::vector<unsigned int>& defaultTestCaseIds,
222 TConstructTestCaseProvider constructTestCaseProvider);
223
224template<typename TDatabase,
225 typename TParser,
226 typename TConstructDatabaseCallable>
227int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
228 const char* inputBindingName, const char* outputBindingName,
229 const std::vector<unsigned int>& defaultTestCaseIds,
230 TConstructDatabaseCallable constructDatabase,
231 const armnn::TensorShape* inputTensorShape = nullptr);
232
233} // namespace test
234} // namespace armnn
235
236#include "InferenceTest.inl"