blob: d0bb0c00f3c2b0ffa6f9b2ea7e0c75d79bb920be [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Jan Eilers8eb25602020-03-09 12:13:48 +00007#include "InferenceModel.hpp"
8
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnn/ArmNN.hpp>
David Monahan6bb47a72021-10-22 12:57:28 +010010#include <armnn/Utils.hpp>
Derek Lamberti08446972019-11-26 16:38:31 +000011#include <armnn/Logging.hpp>
David Beckf0b48452018-10-19 15:20:56 +010012#include <armnn/TypesUtils.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000013#include <armnn/utility/IgnoreUnused.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010014
James Wardc89829f2020-10-12 14:17:36 +010015#include <cxxopts/cxxopts.hpp>
16#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000017
telsoa01c577f2c2018-08-31 09:22:23 +010018
telsoa014fcda012018-03-09 14:13:49 +000019namespace armnn
20{
21
22inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
23{
24 std::string token;
25 in >> token;
26 compute = armnn::ParseComputeDevice(token.c_str());
27 if (compute == armnn::Compute::Undefined)
28 {
29 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010030 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
telsoa014fcda012018-03-09 14:13:49 +000031 }
32 return in;
33}
34
David Beckf0b48452018-10-19 15:20:56 +010035inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
36{
37 std::string token;
38 in >> token;
39 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
40 if (compute == armnn::Compute::Undefined)
41 {
42 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010043 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
David Beckf0b48452018-10-19 15:20:56 +010044 }
45 backend = compute;
46 return in;
47}
48
telsoa014fcda012018-03-09 14:13:49 +000049namespace test
50{
51
52class TestFrameworkException : public Exception
53{
54public:
55 using Exception::Exception;
56};
57
/// Options shared by all inference test cases.
/// ParseCommandLine() takes an instance as its output parameter.
struct InferenceTestOptions
{
    /// Iteration count; 0 by default.
    /// NOTE(review): exact meaning of 0 vs. non-zero should be confirmed in InferenceTest.cpp.
    unsigned int m_IterationCount;
    /// Path of the file that inference times are written to; empty by default.
    std::string m_InferenceTimesFile;
    /// Whether profiling is enabled; false by default.
    bool m_EnableProfiling;
    /// Path used to look up dynamic backends; empty by default.
    std::string m_DynamicBackendsPath;

    // Every member is initialised explicitly so the defaults are visible in one place.
    InferenceTestOptions()
        : m_IterationCount(0)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false)
        , m_DynamicBackendsPath()
    {}
};
71
/// Outcome of a single test case, as returned by IInferenceTestCase::ProcessResult().
enum class TestCaseResult
{
    Ok,     ///< The test completed without any errors.
    Failed, ///< The test failed (e.g. the prediction didn't match the validation file).
            ///< This will eventually fail the whole program but the remaining test cases will still be run.
    Abort   ///< The test failed with a fatal error. The remaining tests will not be run.
};
82
83class IInferenceTestCase
84{
85public:
86 virtual ~IInferenceTestCase() {}
87
88 virtual void Run() = 0;
89 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
90};
91
92class IInferenceTestCaseProvider
93{
94public:
95 virtual ~IInferenceTestCaseProvider() {}
96
James Wardc89829f2020-10-12 14:17:36 +010097 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
Derek Lambertieb1fce02019-12-10 21:20:10 +000098 {
James Wardc89829f2020-10-12 14:17:36 +010099 IgnoreUnused(options, required);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000100 };
101 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
102 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000103 IgnoreUnused(commonOptions);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000104 return true;
105 };
telsoa014fcda012018-03-09 14:13:49 +0000106 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
107 virtual bool OnInferenceTestFinished() { return true; };
108};
109
110template <typename TModel>
111class InferenceModelTestCase : public IInferenceTestCase
112{
113public:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000114
telsoa014fcda012018-03-09 14:13:49 +0000115 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116 unsigned int testCaseId,
David Monahan6bb47a72021-10-22 12:57:28 +0100117 const std::vector<armnn::TContainer>& inputs,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000118 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000119 : m_Model(model)
120 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000121 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000122 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000123 // Initialize output vector
124 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000125 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000126
127 for (size_t i = 0; i < numOutputs; i++)
128 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000129 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000130 }
telsoa014fcda012018-03-09 14:13:49 +0000131 }
132
133 virtual void Run() override
134 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000135 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000136 }
137
138protected:
139 unsigned int GetTestCaseId() const { return m_TestCaseId; }
David Monahan6bb47a72021-10-22 12:57:28 +0100140 const std::vector<armnn::TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000141
142private:
David Monahan6bb47a72021-10-22 12:57:28 +0100143 TModel& m_Model;
144 unsigned int m_TestCaseId;
145 std::vector<armnn::TContainer> m_Inputs;
146 std::vector<armnn::TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000147};
148
149template <typename TTestCaseDatabase, typename TModel>
150class ClassifierTestCase : public InferenceModelTestCase<TModel>
151{
152public:
153 ClassifierTestCase(int& numInferencesRef,
154 int& numCorrectInferencesRef,
155 const std::vector<unsigned int>& validationPredictions,
156 std::vector<unsigned int>* validationPredictionsOut,
157 TModel& model,
158 unsigned int testCaseId,
159 unsigned int label,
160 std::vector<typename TModel::DataType> modelInput);
161
162 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
163
164private:
165 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100166 InferenceModelInternal::QuantizationParams m_QuantizationParams;
167
telsoa014fcda012018-03-09 14:13:49 +0000168 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
169 /// @{
170 int& m_NumInferencesRef;
171 int& m_NumCorrectInferencesRef;
172 const std::vector<unsigned int>& m_ValidationPredictions;
173 std::vector<unsigned int>* m_ValidationPredictionsOut;
174 /// @}
175};
176
177template <typename TDatabase, typename InferenceModel>
178class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
179{
180public:
181 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
182 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
183
James Wardc89829f2020-10-12 14:17:36 +0100184 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100185 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
telsoa014fcda012018-03-09 14:13:49 +0000186 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
187 virtual bool OnInferenceTestFinished() override;
188
189private:
190 void ReadPredictions();
191
192 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100193 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
194 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000195 std::unique_ptr<InferenceModel> m_Model;
196
197 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000199 std::unique_ptr<TDatabase> m_Database;
200
telsoa01c577f2c2018-08-31 09:22:23 +0100201 int m_NumInferences; // Referenced by test cases.
202 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000203
204 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100205 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000206
207 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100208 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000209};
210
211bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
212 InferenceTestOptions& outParams);
213
214bool ValidateDirectory(std::string& dir);
215
216bool InferenceTest(const InferenceTestOptions& params,
217 const std::vector<unsigned int>& defaultTestCaseIds,
218 IInferenceTestCaseProvider& testCaseProvider);
219
220template<typename TConstructTestCaseProvider>
221int InferenceTestMain(int argc,
222 char* argv[],
223 const std::vector<unsigned int>& defaultTestCaseIds,
224 TConstructTestCaseProvider constructTestCaseProvider);
225
226template<typename TDatabase,
227 typename TParser,
228 typename TConstructDatabaseCallable>
229int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
230 const char* inputBindingName, const char* outputBindingName,
231 const std::vector<unsigned int>& defaultTestCaseIds,
232 TConstructDatabaseCallable constructDatabase,
233 const armnn::TensorShape* inputTensorShape = nullptr);
234
235} // namespace test
236} // namespace armnn
237
238#include "InferenceTest.inl"