blob: 5f53c06a8881d66cb2b429a4f07a2f1bd128bcab [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
5#pragma once
6
7#include "armnn/ArmNN.hpp"
8#include "armnn/TypesUtils.hpp"
9#include <Logging.hpp>
10
11#include <boost/log/core/core.hpp>
12#include <boost/program_options.hpp>
13
14namespace armnn
15{
16
17inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
18{
19 std::string token;
20 in >> token;
21 compute = armnn::ParseComputeDevice(token.c_str());
22 if (compute == armnn::Compute::Undefined)
23 {
24 in.setstate(std::ios_base::failbit);
25 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
26 }
27 return in;
28}
29
30namespace test
31{
32
33class TestFrameworkException : public Exception
34{
35public:
36 using Exception::Exception;
37};
38
/// Options for an inference test run; they are passed to each test case's ProcessResult().
struct InferenceTestOptions
{
    /// Iteration count. It is zero after default construction.
    /// NOTE(review): how 0 is interpreted is decided by the code that reads this field — confirm in InferenceTest.cpp.
    unsigned int m_IterationCount = 0;

    /// Name of a file associated with inference timings. It is empty by default.
    /// NOTE(review): presumably the file per-inference times are written to — confirm against the caller.
    std::string m_InferenceTimesFile;

    // Deliberately user-provided (not "= default") so the type remains a non-aggregate, as before.
    InferenceTestOptions() {}
};
48
/// Outcome reported by IInferenceTestCase::ProcessResult().
enum class TestCaseResult
{
    Ok,     ///< The test completed without any errors.
    Failed, ///< The test failed (e.g. the prediction didn't match the validation file).
            ///< This will eventually fail the whole program but the remaining test cases will still be run.
    Abort   ///< The test failed with a fatal error. The remaining tests will not be run.
};
59
60class IInferenceTestCase
61{
62public:
63 virtual ~IInferenceTestCase() {}
64
65 virtual void Run() = 0;
66 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
67};
68
69class IInferenceTestCaseProvider
70{
71public:
72 virtual ~IInferenceTestCaseProvider() {}
73
74 virtual void AddCommandLineOptions(boost::program_options::options_description& options) {};
75 virtual bool ProcessCommandLineOptions() { return true; };
76 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
77 virtual bool OnInferenceTestFinished() { return true; };
78};
79
80template <typename TModel>
81class InferenceModelTestCase : public IInferenceTestCase
82{
83public:
84 InferenceModelTestCase(TModel& model,
85 unsigned int testCaseId,
86 std::vector<typename TModel::DataType> modelInput,
87 unsigned int outputSize)
88 : m_Model(model)
89 , m_TestCaseId(testCaseId)
90 , m_Input(std::move(modelInput))
91 {
92 m_Output.resize(outputSize);
93 }
94
95 virtual void Run() override
96 {
97 m_Model.Run(m_Input, m_Output);
98 }
99
100protected:
101 unsigned int GetTestCaseId() const { return m_TestCaseId; }
102 const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
103
104private:
105 TModel& m_Model;
106 unsigned int m_TestCaseId;
107 std::vector<typename TModel::DataType> m_Input;
108 std::vector<typename TModel::DataType> m_Output;
109};
110
111template <typename TTestCaseDatabase, typename TModel>
112class ClassifierTestCase : public InferenceModelTestCase<TModel>
113{
114public:
115 ClassifierTestCase(int& numInferencesRef,
116 int& numCorrectInferencesRef,
117 const std::vector<unsigned int>& validationPredictions,
118 std::vector<unsigned int>* validationPredictionsOut,
119 TModel& model,
120 unsigned int testCaseId,
121 unsigned int label,
122 std::vector<typename TModel::DataType> modelInput);
123
124 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
125
126private:
127 unsigned int m_Label;
128 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
129 /// @{
130 int& m_NumInferencesRef;
131 int& m_NumCorrectInferencesRef;
132 const std::vector<unsigned int>& m_ValidationPredictions;
133 std::vector<unsigned int>* m_ValidationPredictionsOut;
134 /// @}
135};
136
137template <typename TDatabase, typename InferenceModel>
138class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
139{
140public:
141 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
142 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
143
144 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
145 virtual bool ProcessCommandLineOptions() override;
146 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
147 virtual bool OnInferenceTestFinished() override;
148
149private:
150 void ReadPredictions();
151
152 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
153 std::function<std::unique_ptr<InferenceModel>(typename InferenceModel::CommandLineOptions)> m_ConstructModel;
154 std::unique_ptr<InferenceModel> m_Model;
155
156 std::string m_DataDir;
157 std::function<TDatabase(const char*)> m_ConstructDatabase;
158 std::unique_ptr<TDatabase> m_Database;
159
160 int m_NumInferences; // Referenced by test cases
161 int m_NumCorrectInferences; // Referenced by test cases
162
163 std::string m_ValidationFileIn;
164 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases
165
166 std::string m_ValidationFileOut;
167 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases
168};
169
170bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
171 InferenceTestOptions& outParams);
172
/// Checks a directory path.
///
/// The path is taken by non-const reference, so the implementation may modify it.
/// NOTE(review): presumably it normalises the path and returns true when the directory is usable — confirm in
/// InferenceTest.cpp.
bool ValidateDirectory(std::string& dir);
174
175bool InferenceTest(const InferenceTestOptions& params,
176 const std::vector<unsigned int>& defaultTestCaseIds,
177 IInferenceTestCaseProvider& testCaseProvider);
178
/// Generic entry point for an inference test program.
///
/// @param argc, argv                 Program arguments as received by main().
/// @param defaultTestCaseIds         Test case ids to use by default.
/// @param constructTestCaseProvider  Callable that creates the test-case provider.
/// @return NOTE(review): presumably the process exit code — confirm in InferenceTest.inl.
template <typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);
184
185template<typename TDatabase,
186 typename TParser,
187 typename TConstructDatabaseCallable>
188int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
189 const char* inputBindingName, const char* outputBindingName,
190 const std::vector<unsigned int>& defaultTestCaseIds,
191 TConstructDatabaseCallable constructDatabase,
192 const armnn::TensorShape* inputTensorShape = nullptr);
193
194} // namespace test
195} // namespace armnn
196
197#include "InferenceTest.inl"