blob: 980b3a724c863d55b8217ede2e07436741039010 [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
5#pragma once
6
7#ifndef LOG_TAG
8#define LOG_TAG "ArmnnDriverTests"
9#endif // LOG_TAG
10
#include "../ArmnnDriver.hpp"

#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <mutex>
#include <vector>

#include <boost/test/unit_test.hpp>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
Sadik Armagane6e54a82019-05-08 10:18:05 +010036struct ExecutionCallback : public V1_0::IExecutionCallback
surmeh0149b9e102018-05-17 14:11:25 +010037{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
Sadik Armagane6e54a82019-05-08 10:18:05 +010051class PreparedModelCallback : public V1_0::IPreparedModelCallback
surmeh0149b9e102018-05-17 14:11:25 +010052{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
Sadik Armagane6e54a82019-05-08 10:18:05 +010061 const android::sp<V1_0::IPreparedModel>& preparedModel) override;
surmeh0149b9e102018-05-17 14:11:25 +010062 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
Sadik Armagane6e54a82019-05-08 10:18:05 +010063 android::sp<V1_0::IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
surmeh0149b9e102018-05-17 14:11:25 +010064
65private:
66 ErrorStatus m_ErrorStatus;
Sadik Armagane6e54a82019-05-08 10:18:05 +010067 android::sp<V1_0::IPreparedModel> m_PreparedModel;
surmeh0149b9e102018-05-17 14:11:25 +010068};
69
70hidl_memory allocateSharedMemory(int64_t size);
71
72android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
73
74void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
75
/// Appends @p op to @p model.operands. After the call the new operand's index is
/// model.operands.size() - 1.
template<typename HalPolicy,
         typename HalModel   = typename HalPolicy::Model,
         typename HalOperand = typename HalPolicy::Operand>
void AddOperand(HalModel& model, const HalOperand& op)
{
    // Grow by one and fill the new tail slot (hidl_vec, presumably the container used here,
    // offers resize() but no push_back()).
    const auto newIndex = model.operands.size();
    model.operands.resize(newIndex + 1);
    model.operands[newIndex] = op;
}
surmeh0149b9e102018-05-17 14:11:25 +010084
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010085template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
Nikhil Raj77605822018-09-03 11:25:56 +010086void AddIntOperand(HalModel& model, int32_t value)
87{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010088 using HalOperand = typename HalPolicy::Operand;
89 using HalOperandType = typename HalPolicy::OperandType;
90 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
91
Nikhil Raj77605822018-09-03 11:25:56 +010092 DataLocation location = {};
93 location.offset = model.operandValues.size();
94 location.length = sizeof(int32_t);
95
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +010096 HalOperand op = {};
97 op.type = HalOperandType::INT32;
98 op.dimensions = hidl_vec<uint32_t>{};
99 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
100 op.location = location;
Nikhil Raj77605822018-09-03 11:25:56 +0100101
102 model.operandValues.resize(model.operandValues.size() + location.length);
103 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
104
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100105 AddOperand<HalPolicy>(model, op);
106}
107
108template<typename HalPolicy, typename HalModel = typename HalPolicy::Model>
109void AddBoolOperand(HalModel& model, bool value)
110{
111 using HalOperand = typename HalPolicy::Operand;
112 using HalOperandType = typename HalPolicy::OperandType;
113 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
114
115 DataLocation location = {};
116 location.offset = model.operandValues.size();
117 location.length = sizeof(uint8_t);
118
119 HalOperand op = {};
120 op.type = HalOperandType::BOOL;
121 op.dimensions = hidl_vec<uint32_t>{};
122 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
123 op.location = location;
124
125 model.operandValues.resize(model.operandValues.size() + location.length);
126 *reinterpret_cast<uint8_t*>(&model.operandValues[location.offset]) = static_cast<uint8_t>(value);
127
Nikhil Raj77605822018-09-03 11:25:56 +0100128 AddOperand<HalModel>(model, op);
129}
surmeh0149b9e102018-05-17 14:11:25 +0100130
131template<typename T>
132OperandType TypeToOperandType();
133
134template<>
135OperandType TypeToOperandType<float>();
136
137template<>
138OperandType TypeToOperandType<int32_t>();
139
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100140template<typename HalPolicy,
141 typename T,
142 typename HalModel = typename HalPolicy::Model,
143 typename HalOperandType = typename HalPolicy::OperandType,
144 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Nikhil Raj77605822018-09-03 11:25:56 +0100145void AddTensorOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000146 const hidl_vec<uint32_t>& dimensions,
147 const T* values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100148 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32,
149 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY)
surmeh0149b9e102018-05-17 14:11:25 +0100150{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100151 using HalOperand = typename HalPolicy::Operand;
152
surmeh0149b9e102018-05-17 14:11:25 +0100153 uint32_t totalElements = 1;
154 for (uint32_t dim : dimensions)
155 {
156 totalElements *= dim;
157 }
158
159 DataLocation location = {};
surmeh0149b9e102018-05-17 14:11:25 +0100160 location.length = totalElements * sizeof(T);
161
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100162 if(operandLifeTime == HalOperandLifeTime::CONSTANT_COPY)
Kevin Mayf29a2c52019-03-14 11:56:32 +0000163 {
164 location.offset = model.operandValues.size();
165 }
166
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100167 HalOperand op = {};
168 op.type = operandType;
169 op.dimensions = dimensions;
170 op.lifetime = HalOperandLifeTime::CONSTANT_COPY;
171 op.location = location;
surmeh0149b9e102018-05-17 14:11:25 +0100172
173 model.operandValues.resize(model.operandValues.size() + location.length);
174 for (uint32_t i = 0; i < totalElements; i++)
175 {
176 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
177 }
178
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100179 AddOperand<HalPolicy>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100180}
181
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100182template<typename HalPolicy,
183 typename T,
184 typename HalModel = typename HalPolicy::Model,
185 typename HalOperandType = typename HalPolicy::OperandType,
186 typename HalOperandLifeTime = typename HalPolicy::OperandLifeTime>
Matteo Martincighc7434122018-11-14 12:27:04 +0000187void AddTensorOperand(HalModel& model,
188 const hidl_vec<uint32_t>& dimensions,
189 const std::vector<T>& values,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100190 HalOperandType operandType = HalPolicy::OperandType::TENSOR_FLOAT32,
191 HalOperandLifeTime operandLifeTime = HalOperandLifeTime::CONSTANT_COPY)
Matteo Martincighc7434122018-11-14 12:27:04 +0000192{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100193 AddTensorOperand<HalPolicy, T>(model, dimensions, values.data(), operandType, operandLifeTime);
Matteo Martincighc7434122018-11-14 12:27:04 +0000194}
195
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100196template<typename HalPolicy,
197 typename HalModel = typename HalPolicy::Model,
198 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100199void AddInputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000200 const hidl_vec<uint32_t>& dimensions,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100201 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100202{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100203 using HalOperand = typename HalPolicy::Operand;
204 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100205
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100206 HalOperand op = {};
207 op.type = operandType;
208 op.scale = operandType == HalOperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
209 op.dimensions = dimensions;
210 op.lifetime = HalOperandLifeTime::MODEL_INPUT;
211
212 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100213
214 model.inputIndexes.resize(model.inputIndexes.size() + 1);
215 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
216}
217
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100218template<typename HalPolicy,
219 typename HalModel = typename HalPolicy::Model,
220 typename HalOperandType = typename HalPolicy::OperandType>
Nikhil Raj77605822018-09-03 11:25:56 +0100221void AddOutputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000222 const hidl_vec<uint32_t>& dimensions,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100223 HalOperandType operandType = HalOperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100224{
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100225 using HalOperand = typename HalPolicy::Operand;
226 using HalOperandLifeTime = typename HalPolicy::OperandLifeTime;
surmeh0149b9e102018-05-17 14:11:25 +0100227
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100228 HalOperand op = {};
229 op.type = operandType;
230 op.scale = operandType == HalOperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
231 op.dimensions = dimensions;
232 op.lifetime = HalOperandLifeTime::MODEL_OUTPUT;
233
234 AddOperand<HalPolicy>(model, op);
Nikhil Raj77605822018-09-03 11:25:56 +0100235
236 model.outputIndexes.resize(model.outputIndexes.size() + 1);
237 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
238}
surmeh0149b9e102018-05-17 14:11:25 +0100239
Sadik Armagane6e54a82019-05-08 10:18:05 +0100240android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
241 armnn_driver::ArmnnDriver& driver,
242 ErrorStatus& prepareStatus,
243 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100244
Matteo Martincigha5f9e762019-06-17 13:26:34 +0100245#if defined(ARMNN_ANDROID_NN_V1_1) || defined(ARMNN_ANDROID_NN_V1_2)
Nikhil Raj77605822018-09-03 11:25:56 +0100246
Sadik Armagane6e54a82019-05-08 10:18:05 +0100247android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
Matteo Martincigha5f9e762019-06-17 13:26:34 +0100248 armnn_driver::ArmnnDriver& driver,
249 ErrorStatus& prepareStatus,
250 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100251
252#endif
253
254template<typename HalModel>
Sadik Armagane6e54a82019-05-08 10:18:05 +0100255android::sp<V1_0::IPreparedModel> PrepareModel(const HalModel& model,
Aron Virginas-Tar44cfd842019-06-14 15:45:03 +0100256 armnn_driver::ArmnnDriver& driver)
Nikhil Raj77605822018-09-03 11:25:56 +0100257{
258 ErrorStatus prepareStatus = ErrorStatus::NONE;
259 return PrepareModelWithStatus(model, driver, prepareStatus);
260}
surmeh0149b9e102018-05-17 14:11:25 +0100261
Sadik Armagane6e54a82019-05-08 10:18:05 +0100262ErrorStatus Execute(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100263 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100264 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100265
Sadik Armagane6e54a82019-05-08 10:18:05 +0100266android::sp<ExecutionCallback> ExecuteNoWait(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100267 const Request& request);
268
269} // namespace driverTestHelpers