blob: fb9c04848867981d619c3b7da7b96fa020b28ae0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Jan Eilers8eb25602020-03-09 12:13:48 +00007#include "InferenceModel.hpp"
8
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnn/ArmNN.hpp>
Derek Lamberti08446972019-11-26 16:38:31 +000010#include <armnn/Logging.hpp>
David Beckf0b48452018-10-19 15:20:56 +010011#include <armnn/TypesUtils.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012#include <armnn/utility/IgnoreUnused.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010013
Francis Murtagh40d27412021-10-28 11:11:35 +010014#include <armnnUtils/TContainer.hpp>
15
James Wardc89829f2020-10-12 14:17:36 +010016#include <cxxopts/cxxopts.hpp>
17#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000018
telsoa01c577f2c2018-08-31 09:22:23 +010019
telsoa014fcda012018-03-09 14:13:49 +000020namespace armnn
21{
22
23inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
24{
25 std::string token;
26 in >> token;
27 compute = armnn::ParseComputeDevice(token.c_str());
28 if (compute == armnn::Compute::Undefined)
29 {
30 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010031 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
telsoa014fcda012018-03-09 14:13:49 +000032 }
33 return in;
34}
35
David Beckf0b48452018-10-19 15:20:56 +010036inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
37{
38 std::string token;
39 in >> token;
40 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
41 if (compute == armnn::Compute::Undefined)
42 {
43 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010044 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
David Beckf0b48452018-10-19 15:20:56 +010045 }
46 backend = compute;
47 return in;
48}
49
telsoa014fcda012018-03-09 14:13:49 +000050namespace test
51{
52
53class TestFrameworkException : public Exception
54{
55public:
56 using Exception::Exception;
57};
58
/// Options passed to the inference test entry points declared in this header.
/// The exact meaning of each field is defined by the code that consumes it
/// (not visible in this header); only the default values are documented here.
struct InferenceTestOptions
{
    unsigned int m_IterationCount;      // 0 after default construction
    std::string  m_InferenceTimesFile;  // empty after default construction
    bool         m_EnableProfiling;     // false after default construction
    std::string  m_DynamicBackendsPath; // empty after default construction

    /// Builds the options with zero iterations, profiling disabled and empty strings.
    InferenceTestOptions()
        : m_IterationCount(0u)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false)
        , m_DynamicBackendsPath()
    {
    }
};
72
/// Outcome reported for a single test case.
enum class TestCaseResult
{
    /// No error occurred while running the test.
    Ok,

    /// The test did not pass (for example the prediction differs from the validation file).
    /// The program as a whole will eventually fail, but the remaining test cases still run.
    Failed,

    /// A fatal error occurred; none of the remaining test cases will be run.
    Abort
};
83
84class IInferenceTestCase
85{
86public:
87 virtual ~IInferenceTestCase() {}
88
89 virtual void Run() = 0;
90 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
91};
92
93class IInferenceTestCaseProvider
94{
95public:
96 virtual ~IInferenceTestCaseProvider() {}
97
James Wardc89829f2020-10-12 14:17:36 +010098 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
Derek Lambertieb1fce02019-12-10 21:20:10 +000099 {
James Wardc89829f2020-10-12 14:17:36 +0100100 IgnoreUnused(options, required);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000101 };
102 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
103 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000104 IgnoreUnused(commonOptions);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000105 return true;
106 };
telsoa014fcda012018-03-09 14:13:49 +0000107 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
108 virtual bool OnInferenceTestFinished() { return true; };
109};
110
111template <typename TModel>
112class InferenceModelTestCase : public IInferenceTestCase
113{
114public:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000115
telsoa014fcda012018-03-09 14:13:49 +0000116 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000117 unsigned int testCaseId,
Francis Murtagh40d27412021-10-28 11:11:35 +0100118 const std::vector<armnnUtils::TContainer>& inputs,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000119 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000120 : m_Model(model)
121 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000122 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000123 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000124 // Initialize output vector
125 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000126 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000127
128 for (size_t i = 0; i < numOutputs; i++)
129 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000130 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000131 }
telsoa014fcda012018-03-09 14:13:49 +0000132 }
133
134 virtual void Run() override
135 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000136 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000137 }
138
139protected:
140 unsigned int GetTestCaseId() const { return m_TestCaseId; }
Francis Murtagh40d27412021-10-28 11:11:35 +0100141 const std::vector<armnnUtils::TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000142
143private:
David Monahan6bb47a72021-10-22 12:57:28 +0100144 TModel& m_Model;
145 unsigned int m_TestCaseId;
Francis Murtagh40d27412021-10-28 11:11:35 +0100146 std::vector<armnnUtils::TContainer> m_Inputs;
147 std::vector<armnnUtils::TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000148};
149
150template <typename TTestCaseDatabase, typename TModel>
151class ClassifierTestCase : public InferenceModelTestCase<TModel>
152{
153public:
154 ClassifierTestCase(int& numInferencesRef,
155 int& numCorrectInferencesRef,
156 const std::vector<unsigned int>& validationPredictions,
157 std::vector<unsigned int>* validationPredictionsOut,
158 TModel& model,
159 unsigned int testCaseId,
160 unsigned int label,
161 std::vector<typename TModel::DataType> modelInput);
162
163 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
164
165private:
166 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100167 InferenceModelInternal::QuantizationParams m_QuantizationParams;
168
telsoa014fcda012018-03-09 14:13:49 +0000169 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
170 /// @{
171 int& m_NumInferencesRef;
172 int& m_NumCorrectInferencesRef;
173 const std::vector<unsigned int>& m_ValidationPredictions;
174 std::vector<unsigned int>* m_ValidationPredictionsOut;
175 /// @}
176};
177
178template <typename TDatabase, typename InferenceModel>
179class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
180{
181public:
182 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
183 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
184
James Wardc89829f2020-10-12 14:17:36 +0100185 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100186 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
telsoa014fcda012018-03-09 14:13:49 +0000187 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
188 virtual bool OnInferenceTestFinished() override;
189
190private:
191 void ReadPredictions();
192
193 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100194 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
195 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000196 std::unique_ptr<InferenceModel> m_Model;
197
198 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100199 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000200 std::unique_ptr<TDatabase> m_Database;
201
telsoa01c577f2c2018-08-31 09:22:23 +0100202 int m_NumInferences; // Referenced by test cases.
203 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000204
205 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100206 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000207
208 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100209 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000210};
211
212bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
213 InferenceTestOptions& outParams);
214
/// Declaration only; the definition is not part of this header.
/// NOTE(review): @p dir is taken by non-const reference, so the function may modify it -
/// confirm against the definition.
bool ValidateDirectory(std::string& dir);
216
217bool InferenceTest(const InferenceTestOptions& params,
218 const std::vector<unsigned int>& defaultTestCaseIds,
219 IInferenceTestCaseProvider& testCaseProvider);
220
/// Declaration only; the definition is provided elsewhere (this header includes
/// "InferenceTest.inl" at its end).
/// @tparam TConstructTestCaseProvider Type of the callable passed as @p constructTestCaseProvider.
/// NOTE(review): the int return value is presumably a process exit code - confirm.
template <typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);
226
227template<typename TDatabase,
228 typename TParser,
229 typename TConstructDatabaseCallable>
230int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
231 const char* inputBindingName, const char* outputBindingName,
232 const std::vector<unsigned int>& defaultTestCaseIds,
233 TConstructDatabaseCallable constructDatabase,
234 const armnn::TensorShape* inputTensorShape = nullptr);
235
236} // namespace test
237} // namespace armnn
238
239#include "InferenceTest.inl"