blob: 9da02603bfd579097c519fdebec19f16b1f9d791 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#ifndef LOG_TAG
8#define LOG_TAG "ArmnnDriverTests"
9#endif // LOG_TAG
10
11#include "../ArmnnDriver.hpp"
12#include <iosfwd>
Nikhil Raj77605822018-09-03 11:25:56 +010013#include <boost/test/unit_test.hpp>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
Sadik Armagane6e54a82019-05-08 10:18:05 +010036struct ExecutionCallback : public V1_0::IExecutionCallback
surmeh0149b9e102018-05-17 14:11:25 +010037{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
Sadik Armagane6e54a82019-05-08 10:18:05 +010051class PreparedModelCallback : public V1_0::IPreparedModelCallback
surmeh0149b9e102018-05-17 14:11:25 +010052{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
Sadik Armagane6e54a82019-05-08 10:18:05 +010061 const android::sp<V1_0::IPreparedModel>& preparedModel) override;
surmeh0149b9e102018-05-17 14:11:25 +010062 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
Sadik Armagane6e54a82019-05-08 10:18:05 +010063 android::sp<V1_0::IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
surmeh0149b9e102018-05-17 14:11:25 +010064
65private:
66 ErrorStatus m_ErrorStatus;
Sadik Armagane6e54a82019-05-08 10:18:05 +010067 android::sp<V1_0::IPreparedModel> m_PreparedModel;
surmeh0149b9e102018-05-17 14:11:25 +010068};
69
#ifdef ARMNN_ANDROID_NN_V1_2

/// V1_2::IPreparedModelCallback implementation used by the tests; only compiled when
/// ARMNN_ANDROID_NN_V1_2 is defined.
/// Holds an ErrorStatus (initially ErrorStatus::NONE) plus one V1_0 and one V1_2
/// prepared-model pointer (both initially default-constructed).
/// NOTE(review): notify() / notify_1_2() are defined elsewhere; presumably each stores
/// its arguments into the matching members - confirm against the definitions.
class PreparedModelCallback_1_2 : public V1_2::IPreparedModelCallback
{
public:
    PreparedModelCallback_1_2()
        : m_ErrorStatus(ErrorStatus::NONE)
        , m_PreparedModel()
        , m_PreparedModel_1_2()
    {
    }

    ~PreparedModelCallback_1_2() override {}

    Return<void> notify(ErrorStatus status,
                        const android::sp<V1_0::IPreparedModel>& preparedModel) override;

    Return<void> notify_1_2(ErrorStatus status,
                            const android::sp<V1_2::IPreparedModel>& preparedModel) override;

    /// Returns the currently stored status.
    ErrorStatus GetErrorStatus()
    {
        return m_ErrorStatus;
    }

    /// Returns a copy of the stored V1_0 prepared-model pointer.
    android::sp<V1_0::IPreparedModel> GetPreparedModel()
    {
        return m_PreparedModel;
    }

    /// Returns a copy of the stored V1_2 prepared-model pointer.
    android::sp<V1_2::IPreparedModel> GetPreparedModel_1_2()
    {
        return m_PreparedModel_1_2;
    }

private:
    ErrorStatus                       m_ErrorStatus;
    android::sp<V1_0::IPreparedModel> m_PreparedModel;
    android::sp<V1_2::IPreparedModel> m_PreparedModel_1_2;
};

#endif
99
surmeh0149b9e102018-05-17 14:11:25 +0100100hidl_memory allocateSharedMemory(int64_t size);
101
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100102template<typename T>
103android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request)
104{
105 hidl_memory pool;
surmeh0149b9e102018-05-17 14:11:25 +0100106
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100107 android::sp<IAllocator> allocator = IAllocator::getService("ashmem");
108 allocator->allocate(sizeof(T) * size, [&](bool success, const hidl_memory& mem) {
109 BOOST_TEST(success);
110 pool = mem;
111 });
112
113 request.pools.resize(request.pools.size() + 1);
114 request.pools[request.pools.size() - 1] = pool;
115
116 android::sp<IMemory> mapped = mapMemory(pool);
117 mapped->update();
118 return mapped;
119}
120
121template<typename T>
122void AddPoolAndSetData(uint32_t size, Request& request, const T* data)
123{
124 android::sp<IMemory> memory = AddPoolAndGetData<T>(size, request);
125
126 T* dst = static_cast<T*>(static_cast<void*>(memory->getPointer()));
127
128 memcpy(dst, data, size * sizeof(T));
129}
surmeh0149b9e102018-05-17 14:11:25 +0100130
/// Appends a copy of `op` to the end of model.operands.
///
/// The operand list is grown with resize() and the new last slot is assigned, so the
/// container only needs size(), resize() and operator[].
///
/// @tparam HalPolicy Policy type supplying the default Model and Operand types.
/// @param  model     Model whose operand list is extended by one entry.
/// @param  op        Operand copied into the new last slot.
template<typename HalPolicy,
         typename HalModel   = typename HalPolicy::Model,
         typename HalOperand = typename HalPolicy::Operand>
void AddOperand(HalModel& model, const HalOperand& op)
{
    const auto newSlot = model.operands.size();
    model.operands.resize(newSlot + 1);
    model.operands[newSlot] = op;
}
surmeh0149b9e102018-05-17 14:11:25 +0100139
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100140template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
Nikhil Raj77605822018-09-03 11:25:56 +0100141void AddIntOperand(HalModel& model, int32_t value)
142{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100143 using HalOperand = typename HalPolicy::Operand;
144 using HalOperandType = typename HalPolicy::OperandType;
145 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
146
Nikhil Raj77605822018-09-03 11:25:56 +0100147 DataLocation location = {};
148 location.offset = model.operandValues.size();
149 location.length = sizeof(int32_t);
150
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100151 HalOperand op = {};
152 op.type = HalOperandType::INT32;
153 op.dimensions = hidl_vec<uint32_t>{};
154 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
155 op.location = location;
Nikhil Raj77605822018-09-03 11:25:56 +0100156
157 model.operandValues.resize(model.operandValues.size() + location.length);
158 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
159
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100160 AddOperand<HalPolicy>(model, op);
161}
162
163template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
164void AddBoolOperand(HalModel& model, bool value)
165{
166 using HalOperand = typename HalPolicy::Operand;
167 using HalOperandType = typename HalPolicy::OperandType;
168 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
169
170 DataLocation location = {};
171 location.offset = model.operandValues.size();
172 location.length = sizeof(uint8_t);
173
174 HalOperand op = {};
175 op.type = HalOperandType::BOOL;
176 op.dimensions = hidl_vec<uint32_t>{};
177 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
178 op.location = location;
179
180 model.operandValues.resize(model.operandValues.size() + location.length);
181 *reinterpret_cast<uint8_t*>(&model.operandValues[location.offset]) = static_cast<uint8_t>(value);
182
Nikhil Raj77605822018-09-03 11:25:56 +0100183 AddOperand<HalModel>(model, op);
184}
surmeh0149b9e102018-05-17 14:11:25 +0100185
186template<typename T>
187OperandType TypeToOperandType();
188
189template<>
190OperandType TypeToOperandType<float>();
191
192template<>
193OperandType TypeToOperandType<int32_t>();
194
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100195template<typename HalPolicy,
196 typename T,
197 typename HalModel = typename HalPolicy::Model,
198 typename HalOperandType = typename HalPolicy::OperandType,
199 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Nikhil Raj77605822018-09-03 11:25:56 +0100200void AddTensorOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000201 const hidl_vec<uint32_t>& dimensions,
202 const T* values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100203 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32,
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100204 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY,
205 double scale = 0.f,
206 int offset = 0)
surmeh0149b9e102018-05-17 14:11:25 +0100207{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100208 using HalOperand = typename HalPolicy::Operand;
209
surmeh0149b9e102018-05-17 14:11:25 +0100210 uint32_t totalElements = 1;
211 for (uint32_t dim : dimensions)
212 {
213 totalElements *= dim;
214 }
215
216 DataLocation location = {};
surmeh0149b9e102018-05-17 14:11:25 +0100217 location.length = totalElements * sizeof(T);
218
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100219 if(operandLifeTime == HalOperandLifeTime::CONSTANT_COPY)
Kevin Mayf29a2c52019-03-14 11:56:32 +0000220 {
221 location.offset = model.operandValues.size();
222 }
223
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100224 HalOperand op = {};
225 op.type = operandType;
226 op.dimensions = dimensions;
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100227 op.scale = scale;
228 op.zeroPoint = offset;
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100229 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
230 op.location = location;
surmeh0149b9e102018-05-17 14:11:25 +0100231
232 model.operandValues.resize(model.operandValues.size() + location.length);
233 for (uint32_t i = 0; i < totalElements; i++)
234 {
235 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
236 }
237
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100238 AddOperand<HalPolicy>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100239}
240
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100241template<typename HalPolicy,
242 typename T,
243 typename HalModel = typename HalPolicy::Model,
244 typename HalOperandType = typename HalPolicy::OperandType,
245 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Matteo Martincighc7434122018-11-14 12:27:04 +0000246void AddTensorOperand(HalModel& model,
247 const hidl_vec<uint32_t>& dimensions,
248 const std::vector<T>& values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100249 HalOperandType operandType = HalPolicy::OperandType::TENSOR_FLOAT32,
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100250 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY,
251 double scale = 0.f,
252 int offset = 0)
Matteo Martincighc7434122018-11-14 12:27:04 +0000253{
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100254 AddTensorOperand<HalPolicy, T>(model, dimensions, values.data(), operandType, operandLifeTime, scale, offset);
Matteo Martincighc7434122018-11-14 12:27:04 +0000255}
256
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100257template<typename HalPolicy,
258 typename HalModel = typename HalPolicy::Model,
259 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100260void AddInputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000261 const hidl_vec<uint32_t>& dimensions,
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100262 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32,
263 double scale = 0.f,
264 int offset = 0)
Nikhil Raj77605822018-09-03 11:25:56 +0100265{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100266 using HalOperand = typename HalPolicy::Operand;
267 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100268
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100269 HalOperand op = {};
270 op.type = operandType;
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100271 op.scale = scale;
272 op.zeroPoint = offset;
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100273 op.dimensions = dimensions;
274 op.lifetime = HalOperandLifeTime::MODEL_INPUT;
275
276 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100277
278 model.inputIndexes.resize(model.inputIndexes.size() + 1);
279 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
280}
281
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100282template<typename HalPolicy,
283 typename HalModel = typename HalPolicy::Model,
284 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100285void AddOutputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000286 const hidl_vec<uint32_t>& dimensions,
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100287 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32,
288 double scale = 0.f,
289 int offset = 0)
Nikhil Raj77605822018-09-03 11:25:56 +0100290{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100291 using HalOperand = typename HalPolicy::Operand;
292 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100293
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100294 HalOperand op = {};
295 op.type = operandType;
Ellen Norris-Thompson976ad3e2019-08-21 15:21:14 +0100296 op.scale = scale;
297 op.zeroPoint = offset;
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100298 op.dimensions = dimensions;
299 op.lifetime = HalOperandLifeTime::MODEL_OUTPUT;
300
301 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100302
303 model.outputIndexes.resize(model.outputIndexes.size() + 1);
304 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
305}
surmeh0149b9e102018-05-17 14:11:25 +0100306
Sadik Armagane6e54a82019-05-08 10:18:05 +0100307android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
308 armnn_driver::ArmnnDriver& driver,
309 ErrorStatus& prepareStatus,
310 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100311
#if defined(ARMNN_ANDROID_NN_V1_1) || defined(ARMNN_ANDROID_NN_V1_2)

/// Overload for V1_1 models, available when either ARMNN_ANDROID_NN_V1_1 or
/// ARMNN_ANDROID_NN_V1_2 is defined. Declaration only; same parameter layout and
/// return type as the V1_0 overload.
android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
                                                         armnn_driver::ArmnnDriver& driver,
                                                         ErrorStatus& prepareStatus,
                                                         ErrorStatus expectedStatus = ErrorStatus::NONE);

#endif
320
321template<typename HalModel>
Sadik Armagane6e54a82019-05-08 10:18:05 +0100322android::sp<V1_0::IPreparedModel> PrepareModel(const HalModel& model,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100323 armnn_driver::ArmnnDriver& driver)
Nikhil Raj77605822018-09-03 11:25:56 +0100324{
325 ErrorStatus prepareStatus = ErrorStatus::NONE;
326 return PrepareModelWithStatus(model, driver, prepareStatus);
327}
surmeh0149b9e102018-05-17 14:11:25 +0100328
#ifdef ARMNN_ANDROID_NN_V1_2

/// V1_2 counterpart of PrepareModelWithStatus(), returning a V1_2::IPreparedModel.
/// Declaration only; the definition is not part of this header.
/// `prepareStatus` is an output reference; `expectedStatus` defaults to ErrorStatus::NONE.
android::sp<V1_2::IPreparedModel> PrepareModelWithStatus_1_2(const armnn_driver::hal_1_2::HalPolicy::Model& model,
                                                             armnn_driver::ArmnnDriver& driver,
                                                             ErrorStatus& prepareStatus,
                                                             ErrorStatus expectedStatus = ErrorStatus::NONE);

/// Convenience wrapper around PrepareModelWithStatus_1_2() for callers that do not need
/// the status: a local ErrorStatus receives it and is discarded, and the expected status
/// is left at that function's default.
///
/// @param model  Model forwarded to PrepareModelWithStatus_1_2().
/// @param driver Driver forwarded to PrepareModelWithStatus_1_2().
/// @return Whatever PrepareModelWithStatus_1_2() returns.
template<typename HalModel>
android::sp<V1_2::IPreparedModel> PrepareModel_1_2(const HalModel& model,
                                                   armnn_driver::ArmnnDriver& driver)
{
    ErrorStatus discardedStatus = ErrorStatus::NONE;
    android::sp<V1_2::IPreparedModel> preparedModel = PrepareModelWithStatus_1_2(model, driver, discardedStatus);
    return preparedModel;
}

#endif
345
346
Sadik Armagane6e54a82019-05-08 10:18:05 +0100347ErrorStatus Execute(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100348 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100349 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100350
Sadik Armagane6e54a82019-05-08 10:18:05 +0100351android::sp<ExecutionCallback> ExecuteNoWait(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100352 const Request& request);
353
354} // namespace driverTestHelpers