blob: 5ec744ca7e932534f7bae69ec1f49ddc92d360c6 [file] [log] [blame]
Laurent Carlier749294b2020-06-01 09:03:17 +01001//
telsoa014fcda012018-03-09 14:13:49 +00002// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa014fcda012018-03-09 14:13:49 +00004//
5#pragma once
6
Jan Eilers8eb25602020-03-09 12:13:48 +00007#include "InferenceModel.hpp"
8
David Beckf0b48452018-10-19 15:20:56 +01009#include <armnn/ArmNN.hpp>
Derek Lamberti08446972019-11-26 16:38:31 +000010#include <armnn/Logging.hpp>
David Beckf0b48452018-10-19 15:20:56 +010011#include <armnn/TypesUtils.hpp>
Jan Eilers8eb25602020-03-09 12:13:48 +000012#include <armnn/utility/IgnoreUnused.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010013
James Wardc89829f2020-10-12 14:17:36 +010014#include <cxxopts/cxxopts.hpp>
15#include <fmt/format.h>
telsoa014fcda012018-03-09 14:13:49 +000016
telsoa01c577f2c2018-08-31 09:22:23 +010017
telsoa014fcda012018-03-09 14:13:49 +000018namespace armnn
19{
20
21inline std::istream& operator>>(std::istream& in, armnn::Compute& compute)
22{
23 std::string token;
24 in >> token;
25 compute = armnn::ParseComputeDevice(token.c_str());
26 if (compute == armnn::Compute::Undefined)
27 {
28 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010029 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
telsoa014fcda012018-03-09 14:13:49 +000030 }
31 return in;
32}
33
David Beckf0b48452018-10-19 15:20:56 +010034inline std::istream& operator>>(std::istream& in, armnn::BackendId& backend)
35{
36 std::string token;
37 in >> token;
38 armnn::Compute compute = armnn::ParseComputeDevice(token.c_str());
39 if (compute == armnn::Compute::Undefined)
40 {
41 in.setstate(std::ios_base::failbit);
James Wardc89829f2020-10-12 14:17:36 +010042 throw cxxopts::OptionException(fmt::format("Unrecognised compute device: {}", token));
David Beckf0b48452018-10-19 15:20:56 +010043 }
44 backend = compute;
45 return in;
46}
47
telsoa014fcda012018-03-09 14:13:49 +000048namespace test
49{
50
51class TestFrameworkException : public Exception
52{
53public:
54 using Exception::Exception;
55};
56
/// Options shared by every inference test. ParseCommandLine() fills an instance through its
/// outParams argument, and the result is handed to test cases and providers.
struct InferenceTestOptions
{
    /// Presumably the number of test-case iterations to run; 0 by default (TODO confirm how 0 is treated).
    unsigned int m_IterationCount;
    /// Presumably a path to which per-inference timings are written; empty by default.
    std::string m_InferenceTimesFile;
    /// Profiling switch; disabled by default.
    bool m_EnableProfiling;
    /// Presumably a search path for dynamically loaded backends; empty by default.
    std::string m_DynamicBackendsPath;

    InferenceTestOptions()
        : m_IterationCount(0)
        , m_InferenceTimesFile()
        , m_EnableProfiling(false) // bool member: initialise with a bool literal, not the integer 0
        , m_DynamicBackendsPath()
    {}
};
70
/// Outcome reported by IInferenceTestCase::ProcessResult().
enum class TestCaseResult
{
    /// The test completed without any errors.
    Ok = 0,

    /// The test failed (e.g. the prediction didn't match the validation file).
    /// This will eventually fail the whole program but the remaining test cases will still be run.
    Failed = 1,

    /// The test failed with a fatal error. The remaining tests will not be run.
    Abort = 2
};
81
82class IInferenceTestCase
83{
84public:
85 virtual ~IInferenceTestCase() {}
86
87 virtual void Run() = 0;
88 virtual TestCaseResult ProcessResult(const InferenceTestOptions& options) = 0;
89};
90
91class IInferenceTestCaseProvider
92{
93public:
94 virtual ~IInferenceTestCaseProvider() {}
95
James Wardc89829f2020-10-12 14:17:36 +010096 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required)
Derek Lambertieb1fce02019-12-10 21:20:10 +000097 {
James Wardc89829f2020-10-12 14:17:36 +010098 IgnoreUnused(options, required);
Derek Lambertieb1fce02019-12-10 21:20:10 +000099 };
100 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions)
101 {
Jan Eilers8eb25602020-03-09 12:13:48 +0000102 IgnoreUnused(commonOptions);
Derek Lambertieb1fce02019-12-10 21:20:10 +0000103 return true;
104 };
telsoa014fcda012018-03-09 14:13:49 +0000105 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) = 0;
106 virtual bool OnInferenceTestFinished() { return true; };
107};
108
109template <typename TModel>
110class InferenceModelTestCase : public IInferenceTestCase
111{
112public:
James Ward6d9f5c52020-09-28 11:56:35 +0100113 using TContainer = mapbox::util::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000114
telsoa014fcda012018-03-09 14:13:49 +0000115 InferenceModelTestCase(TModel& model,
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000116 unsigned int testCaseId,
117 const std::vector<TContainer>& inputs,
118 const std::vector<unsigned int>& outputSizes)
telsoa014fcda012018-03-09 14:13:49 +0000119 : m_Model(model)
120 , m_TestCaseId(testCaseId)
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000121 , m_Inputs(std::move(inputs))
telsoa014fcda012018-03-09 14:13:49 +0000122 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000123 // Initialize output vector
124 const size_t numOutputs = outputSizes.size();
Ferran Balaguerc602f292019-02-08 17:09:55 +0000125 m_Outputs.reserve(numOutputs);
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000126
127 for (size_t i = 0; i < numOutputs; i++)
128 {
Ferran Balaguerc602f292019-02-08 17:09:55 +0000129 m_Outputs.push_back(std::vector<typename TModel::DataType>(outputSizes[i]));
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000130 }
telsoa014fcda012018-03-09 14:13:49 +0000131 }
132
133 virtual void Run() override
134 {
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000135 m_Model.Run(m_Inputs, m_Outputs);
telsoa014fcda012018-03-09 14:13:49 +0000136 }
137
138protected:
139 unsigned int GetTestCaseId() const { return m_TestCaseId; }
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000140 const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
telsoa014fcda012018-03-09 14:13:49 +0000141
142private:
Aron Virginas-Tar7cf0eaa2019-01-24 17:05:36 +0000143 TModel& m_Model;
144 unsigned int m_TestCaseId;
145 std::vector<TContainer> m_Inputs;
146 std::vector<TContainer> m_Outputs;
telsoa014fcda012018-03-09 14:13:49 +0000147};
148
149template <typename TTestCaseDatabase, typename TModel>
150class ClassifierTestCase : public InferenceModelTestCase<TModel>
151{
152public:
153 ClassifierTestCase(int& numInferencesRef,
154 int& numCorrectInferencesRef,
155 const std::vector<unsigned int>& validationPredictions,
156 std::vector<unsigned int>* validationPredictionsOut,
157 TModel& model,
158 unsigned int testCaseId,
159 unsigned int label,
160 std::vector<typename TModel::DataType> modelInput);
161
162 virtual TestCaseResult ProcessResult(const InferenceTestOptions& params) override;
163
164private:
165 unsigned int m_Label;
telsoa01c577f2c2018-08-31 09:22:23 +0100166 InferenceModelInternal::QuantizationParams m_QuantizationParams;
167
telsoa014fcda012018-03-09 14:13:49 +0000168 /// These fields reference the corresponding member in the ClassifierTestCaseProvider.
169 /// @{
170 int& m_NumInferencesRef;
171 int& m_NumCorrectInferencesRef;
172 const std::vector<unsigned int>& m_ValidationPredictions;
173 std::vector<unsigned int>* m_ValidationPredictionsOut;
174 /// @}
175};
176
177template <typename TDatabase, typename InferenceModel>
178class ClassifierTestCaseProvider : public IInferenceTestCaseProvider
179{
180public:
181 template <typename TConstructDatabaseCallable, typename TConstructModelCallable>
182 ClassifierTestCaseProvider(TConstructDatabaseCallable constructDatabase, TConstructModelCallable constructModel);
183
James Wardc89829f2020-10-12 14:17:36 +0100184 virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100185 virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override;
telsoa014fcda012018-03-09 14:13:49 +0000186 virtual std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override;
187 virtual bool OnInferenceTestFinished() override;
188
189private:
190 void ReadPredictions();
191
192 typename InferenceModel::CommandLineOptions m_ModelCommandLineOptions;
Matthew Bentham3e68b972019-04-09 13:10:46 +0100193 std::function<std::unique_ptr<InferenceModel>(const InferenceTestOptions& commonOptions,
194 typename InferenceModel::CommandLineOptions)> m_ConstructModel;
telsoa014fcda012018-03-09 14:13:49 +0000195 std::unique_ptr<InferenceModel> m_Model;
196
197 std::string m_DataDir;
telsoa01c577f2c2018-08-31 09:22:23 +0100198 std::function<TDatabase(const char*, const InferenceModel&)> m_ConstructDatabase;
telsoa014fcda012018-03-09 14:13:49 +0000199 std::unique_ptr<TDatabase> m_Database;
200
telsoa01c577f2c2018-08-31 09:22:23 +0100201 int m_NumInferences; // Referenced by test cases.
202 int m_NumCorrectInferences; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000203
204 std::string m_ValidationFileIn;
telsoa01c577f2c2018-08-31 09:22:23 +0100205 std::vector<unsigned int> m_ValidationPredictions; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000206
207 std::string m_ValidationFileOut;
telsoa01c577f2c2018-08-31 09:22:23 +0100208 std::vector<unsigned int> m_ValidationPredictionsOut; // Referenced by test cases.
telsoa014fcda012018-03-09 14:13:49 +0000209};
210
211bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
212 InferenceTestOptions& outParams);
213
/// Checks a directory path. Takes @p dir by non-const reference, so the definition (not in this
/// header) may modify it in place -- confirm.
/// @return Presumably true when the directory is usable -- confirm.
bool ValidateDirectory(std::string& dir);
215
216bool InferenceTest(const InferenceTestOptions& params,
217 const std::vector<unsigned int>& defaultTestCaseIds,
218 IInferenceTestCaseProvider& testCaseProvider);
219
/// Entry-point helper for an inference test executable. Presumably builds the provider with
/// @p constructTestCaseProvider, parses the command line and runs InferenceTest() -- confirm.
/// @return Presumably a process exit code -- confirm.
template<typename TConstructTestCaseProvider>
int InferenceTestMain(int argc,
                      char* argv[],
                      const std::vector<unsigned int>& defaultTestCaseIds,
                      TConstructTestCaseProvider constructTestCaseProvider);
225
226template<typename TDatabase,
227 typename TParser,
228 typename TConstructDatabaseCallable>
229int ClassifierInferenceTestMain(int argc, char* argv[], const char* modelFilename, bool isModelBinary,
230 const char* inputBindingName, const char* outputBindingName,
231 const std::vector<unsigned int>& defaultTestCaseIds,
232 TConstructDatabaseCallable constructDatabase,
233 const armnn::TensorShape* inputTensorShape = nullptr);
234
235} // namespace test
236} // namespace armnn
237
238#include "InferenceTest.inl"