//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

// Default tag for log output produced by the driver tests (LOG_TAG is conventionally read by
// Android's logging macros). A definition supplied by the build takes precedence.
#if !defined(LOG_TAG)
#define LOG_TAG "ArmnnDriverTests"
#endif // !defined(LOG_TAG)
10
#include "../ArmnnDriver.hpp"

#include <cstring>
#include <iosfwd>

#include <boost/test/unit_test.hpp>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
Sadik Armagane6e54a82019-05-08 10:18:05 +010036struct ExecutionCallback : public V1_0::IExecutionCallback
surmeh0149b9e102018-05-17 14:11:25 +010037{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
Sadik Armagane6e54a82019-05-08 10:18:05 +010051class PreparedModelCallback : public V1_0::IPreparedModelCallback
surmeh0149b9e102018-05-17 14:11:25 +010052{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
Sadik Armagane6e54a82019-05-08 10:18:05 +010061 const android::sp<V1_0::IPreparedModel>& preparedModel) override;
surmeh0149b9e102018-05-17 14:11:25 +010062 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
Sadik Armagane6e54a82019-05-08 10:18:05 +010063 android::sp<V1_0::IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
surmeh0149b9e102018-05-17 14:11:25 +010064
65private:
66 ErrorStatus m_ErrorStatus;
Sadik Armagane6e54a82019-05-08 10:18:05 +010067 android::sp<V1_0::IPreparedModel> m_PreparedModel;
surmeh0149b9e102018-05-17 14:11:25 +010068};
69
#ifdef ARMNN_ANDROID_NN_V1_2

/// Test implementation of V1_2::IPreparedModelCallback. The getters expose the members that the
/// out-of-line notify() / notify_1_2() definitions are expected to fill in.
class PreparedModelCallback_1_2 : public V1_2::IPreparedModelCallback
{
public:
    PreparedModelCallback_1_2() = default;
    ~PreparedModelCallback_1_2() override = default;

    Return<void> notify(ErrorStatus status, const android::sp<V1_0::IPreparedModel>& preparedModel) override;

    Return<void> notify_1_2(ErrorStatus status, const android::sp<V1_2::IPreparedModel>& preparedModel) override;

    /// Returns the stored status; ErrorStatus::NONE until something assigns it.
    ErrorStatus GetErrorStatus() const { return m_ErrorStatus; }

    /// Returns the stored V1_0 prepared model (default-constructed sp until assigned).
    android::sp<V1_0::IPreparedModel> GetPreparedModel() const { return m_PreparedModel; }

    /// Returns the stored V1_2 prepared model (default-constructed sp until assigned).
    android::sp<V1_2::IPreparedModel> GetPreparedModel_1_2() const { return m_PreparedModel_1_2; }

private:
    ErrorStatus                       m_ErrorStatus = ErrorStatus::NONE;
    android::sp<V1_0::IPreparedModel> m_PreparedModel;
    android::sp<V1_2::IPreparedModel> m_PreparedModel_1_2;
};

#endif // ARMNN_ANDROID_NN_V1_2
99
surmeh0149b9e102018-05-17 14:11:25 +0100100hidl_memory allocateSharedMemory(int64_t size);
101
102android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
103
104void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
105
/// Appends a copy of `op` to the end of `model.operands`.
template<typename HalPolicy,
         typename HalModel   = typename HalPolicy::Model,
         typename HalOperand = typename HalPolicy::Operand>
void AddOperand(HalModel& model, const HalOperand& op)
{
    const auto newIndex = model.operands.size();
    model.operands.resize(newIndex + 1);
    model.operands[newIndex] = op;
}
surmeh0149b9e102018-05-17 14:11:25 +0100114
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100115template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
Nikhil Raj77605822018-09-03 11:25:56 +0100116void AddIntOperand(HalModel& model, int32_t value)
117{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100118 using HalOperand = typename HalPolicy::Operand;
119 using HalOperandType = typename HalPolicy::OperandType;
120 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
121
Nikhil Raj77605822018-09-03 11:25:56 +0100122 DataLocation location = {};
123 location.offset = model.operandValues.size();
124 location.length = sizeof(int32_t);
125
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100126 HalOperand op = {};
127 op.type = HalOperandType::INT32;
128 op.dimensions = hidl_vec<uint32_t>{};
129 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
130 op.location = location;
Nikhil Raj77605822018-09-03 11:25:56 +0100131
132 model.operandValues.resize(model.operandValues.size() + location.length);
133 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
134
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100135 AddOperand<HalPolicy>(model, op);
136}
137
138template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
139void AddBoolOperand(HalModel& model, bool value)
140{
141 using HalOperand = typename HalPolicy::Operand;
142 using HalOperandType = typename HalPolicy::OperandType;
143 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
144
145 DataLocation location = {};
146 location.offset = model.operandValues.size();
147 location.length = sizeof(uint8_t);
148
149 HalOperand op = {};
150 op.type = HalOperandType::BOOL;
151 op.dimensions = hidl_vec<uint32_t>{};
152 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
153 op.location = location;
154
155 model.operandValues.resize(model.operandValues.size() + location.length);
156 *reinterpret_cast<uint8_t*>(&model.operandValues[location.offset]) = static_cast<uint8_t>(value);
157
Nikhil Raj77605822018-09-03 11:25:56 +0100158 AddOperand<HalModel>(model, op);
159}
surmeh0149b9e102018-05-17 14:11:25 +0100160
161template<typename T>
162OperandType TypeToOperandType();
163
164template<>
165OperandType TypeToOperandType<float>();
166
167template<>
168OperandType TypeToOperandType<int32_t>();
169
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100170template<typename HalPolicy,
171 typename T,
172 typename HalModel = typename HalPolicy::Model,
173 typename HalOperandType = typename HalPolicy::OperandType,
174 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Nikhil Raj77605822018-09-03 11:25:56 +0100175void AddTensorOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000176 const hidl_vec<uint32_t>& dimensions,
177 const T* values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100178 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32,
179 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY)
surmeh0149b9e102018-05-17 14:11:25 +0100180{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100181 using HalOperand = typename HalPolicy::Operand;
182
surmeh0149b9e102018-05-17 14:11:25 +0100183 uint32_t totalElements = 1;
184 for (uint32_t dim : dimensions)
185 {
186 totalElements *= dim;
187 }
188
189 DataLocation location = {};
surmeh0149b9e102018-05-17 14:11:25 +0100190 location.length = totalElements * sizeof(T);
191
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100192 if(operandLifeTime == HalOperandLifeTime::CONSTANT_COPY)
Kevin Mayf29a2c52019-03-14 11:56:32 +0000193 {
194 location.offset = model.operandValues.size();
195 }
196
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100197 HalOperand op = {};
198 op.type = operandType;
199 op.dimensions = dimensions;
200 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
201 op.location = location;
surmeh0149b9e102018-05-17 14:11:25 +0100202
203 model.operandValues.resize(model.operandValues.size() + location.length);
204 for (uint32_t i = 0; i < totalElements; i++)
205 {
206 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
207 }
208
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100209 AddOperand<HalPolicy>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100210}
211
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100212template<typename HalPolicy,
213 typename T,
214 typename HalModel = typename HalPolicy::Model,
215 typename HalOperandType = typename HalPolicy::OperandType,
216 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Matteo Martincighc7434122018-11-14 12:27:04 +0000217void AddTensorOperand(HalModel& model,
218 const hidl_vec<uint32_t>& dimensions,
219 const std::vector<T>& values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100220 HalOperandType operandType = HalPolicy::OperandType::TENSOR_FLOAT32,
221 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY)
Matteo Martincighc7434122018-11-14 12:27:04 +0000222{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100223 AddTensorOperand<HalPolicy, T>(model, dimensions, values.data(), operandType, operandLifeTime);
Matteo Martincighc7434122018-11-14 12:27:04 +0000224}
225
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100226template<typename HalPolicy,
227 typename HalModel = typename HalPolicy::Model,
228 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100229void AddInputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000230 const hidl_vec<uint32_t>& dimensions,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100231 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100232{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100233 using HalOperand = typename HalPolicy::Operand;
234 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100235
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100236 HalOperand op = {};
237 op.type = operandType;
238 op.scale = operandType == HalOperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
239 op.dimensions = dimensions;
240 op.lifetime = HalOperandLifeTime::MODEL_INPUT;
241
242 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100243
244 model.inputIndexes.resize(model.inputIndexes.size() + 1);
245 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
246}
247
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100248template<typename HalPolicy,
249 typename HalModel = typename HalPolicy::Model,
250 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100251void AddOutputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000252 const hidl_vec<uint32_t>& dimensions,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100253 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100254{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100255 using HalOperand = typename HalPolicy::Operand;
256 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100257
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100258 HalOperand op = {};
259 op.type = operandType;
260 op.scale = operandType == HalOperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
261 op.dimensions = dimensions;
262 op.lifetime = HalOperandLifeTime::MODEL_OUTPUT;
263
264 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100265
266 model.outputIndexes.resize(model.outputIndexes.size() + 1);
267 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
268}
surmeh0149b9e102018-05-17 14:11:25 +0100269
Sadik Armagane6e54a82019-05-08 10:18:05 +0100270android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
271 armnn_driver::ArmnnDriver& driver,
272 ErrorStatus& prepareStatus,
273 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100274
Matteo Martincigha5f9e762019-06-17 13:26:34 +0100275#if defined(ARMNN_ANDROID_NN_V1_1) || defined(ARMNN_ANDROID_NN_V1_2)
Nikhil Raj77605822018-09-03 11:25:56 +0100276
Sadik Armagane6e54a82019-05-08 10:18:05 +0100277android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
Matteo Martincigha5f9e762019-06-17 13:26:34 +0100278 armnn_driver::ArmnnDriver& driver,
279 ErrorStatus& prepareStatus,
280 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100281
282#endif
283
284template<typename HalModel>
Sadik Armagane6e54a82019-05-08 10:18:05 +0100285android::sp<V1_0::IPreparedModel> PrepareModel(const HalModel& model,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100286 armnn_driver::ArmnnDriver& driver)
Nikhil Raj77605822018-09-03 11:25:56 +0100287{
288 ErrorStatus prepareStatus = ErrorStatus::NONE;
289 return PrepareModelWithStatus(model, driver, prepareStatus);
290}
surmeh0149b9e102018-05-17 14:11:25 +0100291
#ifdef ARMNN_ANDROID_NN_V1_2

/// Prepares a HAL 1.2 model with `driver` and writes the preparation status to `prepareStatus`.
/// `expectedStatus` (default ErrorStatus::NONE) is presumably the status the test anticipates —
/// see the out-of-line definition.
android::sp<V1_2::IPreparedModel> PrepareModelWithStatus_1_2(const armnn_driver::hal_1_2::HalPolicy::Model& model,
                                                             armnn_driver::ArmnnDriver& driver,
                                                             ErrorStatus& prepareStatus,
                                                             ErrorStatus expectedStatus = ErrorStatus::NONE);

/// Convenience wrapper around PrepareModelWithStatus_1_2() for callers that do not need the
/// preparation status; the expected status is left at its ErrorStatus::NONE default.
template<typename HalModel>
android::sp<V1_2::IPreparedModel> PrepareModel_1_2(const HalModel& model,
                                                   armnn_driver::ArmnnDriver& driver)
{
    ErrorStatus status{ErrorStatus::NONE};
    return PrepareModelWithStatus_1_2(model, driver, status);
}

#endif // ARMNN_ANDROID_NN_V1_2
308
309
Sadik Armagane6e54a82019-05-08 10:18:05 +0100310ErrorStatus Execute(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100311 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100312 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100313
Sadik Armagane6e54a82019-05-08 10:18:05 +0100314android::sp<ExecutionCallback> ExecuteNoWait(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100315 const Request& request);
316
317} // namespace driverTestHelpers