blob: 68c168f1263ac989c918dc368330112ec46428fe [file] [log] [blame]
telsoa014fcda012018-03-09 14:13:49 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Jan Eilers8eb25602020-03-09 12:13:48 +00007#include "InferenceModel.hpp"
8
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnn/ArmNN.hpp>
Derek Lamberti08446972019-11-26 16:38:31 +000010#include <armnn/Logging.hpp>
David Beckf0b48452018-10-19 15:20:56 +010011#include <armnn/TypesUtils.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012#include <armnn/utility/IgnoreUnused.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010013
telsoa014fcda012018-03-09 14:13:49 +000014#include <boost/program_options.hpp>
15
telsoa01c577f2c2018-08-31 09:22:23 +010016
telsoa014fcda012018-03-09 14:13:49 +000017namespace armnn
18{
19
20inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
21{
22 std::string token;
23 in >> token;
24 compute = armnn::ParseComputeDevice(token.c_str());
25 if (compute == armnn::Compute::Undefined)
26 {
27 in.setstate(std::ios_base::failbit);
28 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
29 }
30 return in;
31}
32
David Beckf0b48452018-10-19 15:20:56 +010033inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
34{
35 std::string token;
36 in >> token;
37 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
38 if (compute == armnn::Compute::Undefined)
39 {
40 in.setstate(std::ios_base::failbit);
41 throw boost::program_options::validation_error(boost::program_options::validation_error::invalid_option_value);
42 }
43 backend = compute;
44 return in;
45}
46
telsoa014fcda012018-03-09 14:13:49 +000047namespace test
48{
49
50class TestFrameworkException : public Exception
51{
52public:
53 using Exception::Exception;
54};
55
/// Options common to every inference test program (ParseCommandLine() fills these in).
/// Default member initializers replace the hand-written constructor, which initialized the
/// bool from the integer literal 0 and value-initialized an already-empty string.
struct InferenceTestOptions
{
    unsigned int m_IterationCount = 0;
    std::string m_InferenceTimesFile;   // Empty by default.
    bool m_EnableProfiling = false;
    std::string m_DynamicBackendsPath;  // Empty by default.
};
69
/// Outcome reported by IInferenceTestCase::ProcessResult for a single test case.
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok = 0,

    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed = 1,

    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort = 2
};
80
81class IInferenceTestCase
82{
83public:
84 virtual ~IInferenceTestCase() {}
85
86 virtual void Run() = 0;
87 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
88};
89
90class IInferenceTestCaseProvider
91{
92public:
93 virtual ~IInferenceTestCaseProvider() {}
94
Derek Lambertieb1fce02019-12-10 21:20:10 +000095 virtual void AddCommandLineOptions(boost::program_options::options_description& options)
96 {
Jan Eilers8eb25602020-03-09 12:13:48 +000097 IgnoreUnused(options);
Derek Lambertieb1fce02019-12-10 21:20:10 +000098 };
99 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
100 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000101 IgnoreUnused(commonOptions);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000102 return true;
103 };
telsoa014fcda012018-03-09 14:13:49 +0000104 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
105 virtual bool OnInferenceTestFinished() { return true; };
106};
107
108template <typename TModel>
109class InferenceModelTestCase : public IInferenceTestCase
110{
111public:
Ferran Balaguerc602f292019-02-08 17:09:55 +0000112 using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000113
telsoa014fcda012018-03-09 14:13:49 +0000114 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000115 unsigned int testCaseId,
116 const std::vector<TContainer>& inputs,
117 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000118 : m_Model(model)
119 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000120 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000121 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000122 // Initialize output vector
123 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000124 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000125
126 for (size_t i = 0; i < numOutputs; i++)
127 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000128 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000129 }
telsoa014fcda012018-03-09 14:13:49 +0000130 }
131
132 virtual void Run() override
133 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000134 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000135 }
136
137protected:
138 unsigned int GetTestCaseId() const { return m_TestCaseId; }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000139 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000140
141private:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000142 TModel& m_Model;
143 unsigned int m_TestCaseId;
144 std::vector<TContainer> m_Inputs;
145 std::vector<TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000146};
147
148template <typename TTestCaseDatabase, typename TModel>
149class ClassifierTestCase : public InferenceModelTestCase<TModel>
150{
151public:
152 ClassifierTestCase(int& numInferencesRef,
153 int& numCorrectInferencesRef,
154 const std::vector<unsigned int>& validationPredictions,
155 std::vector<unsigned int>* validationPredictionsOut,
156 TModel& model,
157 unsigned int testCaseId,
158 unsigned int label,
159 std::vector<typename TModel::DataType> modelInput);
160
161 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
162
163private:
164 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100165 InferenceModelInternal::QuantizationParams m_QuantizationParams;
166
telsoa014fcda012018-03-09 14:13:49 +0000167 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
168 /// @{
169 int& m_NumInferencesRef;
170 int& m_NumCorrectInferencesRef;
171 const std::vector<unsigned int>& m_ValidationPredictions;
172 std::vector<unsigned int>* m_ValidationPredictionsOut;
173 /// @}
174};
175
176template <typename TDatabase, typename InferenceModel>
177class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
178{
179public:
180 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
181 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
182
183 virtual void AddCommandLineOptions(boost::program_options::options_description& options) override;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100184 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
telsoa014fcda012018-03-09 14:13:49 +0000185 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
186 virtual bool OnInferenceTestFinished() override;
187
188private:
189 void ReadPredictions();
190
191 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100192 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
193 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000194 std::unique_ptr<InferenceModel> m_Model;
195
196 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100197 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000198 std::unique_ptr<TDatabase> m_Database;
199
telsoa01c577f2c2018-08-31 09:22:23 +0100200 int m_NumInferences; // Referenced by test cases.
201 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000202
203 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100204 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000205
206 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100207 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000208};
209
210bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
211 InferenceTestOptions& outParams);
212
213bool ValidateDirectory(std::string& dir);
214
215bool InferenceTest(const InferenceTestOptions& params,
216 const std::vector<unsigned int>& defaultTestCaseIds,
217 IInferenceTestCaseProvider& testCaseProvider);
218
219template<typename TConstructTestCaseProvider>
220int InferenceTestMain(int argc,
221 char* argv[],
222 const std::vector<unsigned int>& defaultTestCaseIds,
223 TConstructTestCaseProvider constructTestCaseProvider);
224
225template<typename TDatabase,
226 typename TParser,
227 typename TConstructDatabaseCallable>
228int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
229 const char* inputBindingName, const char* outputBindingName,
230 const std::vector<unsigned int>& defaultTestCaseIds,
231 TConstructDatabaseCallable constructDatabase,
232 const armnn::TensorShape* inputTensorShape = nullptr);
233
234} // namespace test
235} // namespace armnn
236
237#include "InferenceTest.inl"