blob: 4a8f607e931de1bc394f3415b9c3a1bf166e65c0 [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
5#pragma once
6
7#ifndef LOG_TAG
8#define LOG_TAG "ArmnnDriverTests"
9#endif // LOG_TAG
10
#include "../ArmnnDriver.hpp"
#include <cstring>
#include <iosfwd>
#include <boost/test/unit_test.hpp>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
Sadik Armagane6e54a82019-05-08 10:18:05 +010036struct ExecutionCallback : public V1_0::IExecutionCallback
surmeh0149b9e102018-05-17 14:11:25 +010037{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
Sadik Armagane6e54a82019-05-08 10:18:05 +010051class PreparedModelCallback : public V1_0::IPreparedModelCallback
surmeh0149b9e102018-05-17 14:11:25 +010052{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
Sadik Armagane6e54a82019-05-08 10:18:05 +010061 const android::sp<V1_0::IPreparedModel>& preparedModel) override;
surmeh0149b9e102018-05-17 14:11:25 +010062 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
Sadik Armagane6e54a82019-05-08 10:18:05 +010063 android::sp<V1_0::IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
surmeh0149b9e102018-05-17 14:11:25 +010064
65private:
66 ErrorStatus m_ErrorStatus;
Sadik Armagane6e54a82019-05-08 10:18:05 +010067 android::sp<V1_0::IPreparedModel> m_PreparedModel;
surmeh0149b9e102018-05-17 14:11:25 +010068};
69
70hidl_memory allocateSharedMemory(int64_t size);
71
72android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
73
74void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
75
Nikhil Raj77605822018-09-03 11:25:56 +010076template<typename HalModel>
Sadik Armagane6e54a82019-05-08 10:18:05 +010077void AddOperand(HalModel& model, const V1_0::Operand& op)
Nikhil Raj77605822018-09-03 11:25:56 +010078{
79 model.operands.resize(model.operands.size() + 1);
80 model.operands[model.operands.size() - 1] = op;
81}
surmeh0149b9e102018-05-17 14:11:25 +010082
Nikhil Raj77605822018-09-03 11:25:56 +010083template<typename HalModel>
84void AddIntOperand(HalModel& model, int32_t value)
85{
86 DataLocation location = {};
87 location.offset = model.operandValues.size();
88 location.length = sizeof(int32_t);
89
Sadik Armagane6e54a82019-05-08 10:18:05 +010090 V1_0::Operand op = {};
91 op.type = V1_0::OperandType::INT32;
92 op.dimensions = hidl_vec<uint32_t>{};
93 op.lifetime = V1_0::OperandLifeTime::CONSTANT_COPY;
94 op.location = location;
Nikhil Raj77605822018-09-03 11:25:56 +010095
96 model.operandValues.resize(model.operandValues.size() + location.length);
97 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
98
99 AddOperand<HalModel>(model, op);
100}
surmeh0149b9e102018-05-17 14:11:25 +0100101
102template<typename T>
103OperandType TypeToOperandType();
104
105template<>
106OperandType TypeToOperandType<float>();
107
108template<>
109OperandType TypeToOperandType<int32_t>();
110
Nikhil Raj77605822018-09-03 11:25:56 +0100111template<typename HalModel, typename T>
112void AddTensorOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000113 const hidl_vec<uint32_t>& dimensions,
114 const T* values,
Sadik Armagane6e54a82019-05-08 10:18:05 +0100115 V1_0::OperandType operandType = V1_0::OperandType::TENSOR_FLOAT32,
116 V1_0::OperandLifeTime operandLifeTime = V1_0::OperandLifeTime::CONSTANT_COPY)
surmeh0149b9e102018-05-17 14:11:25 +0100117{
118 uint32_t totalElements = 1;
119 for (uint32_t dim : dimensions)
120 {
121 totalElements *= dim;
122 }
123
124 DataLocation location = {};
surmeh0149b9e102018-05-17 14:11:25 +0100125 location.length = totalElements * sizeof(T);
126
Sadik Armagane6e54a82019-05-08 10:18:05 +0100127 if(operandLifeTime == V1_0::OperandLifeTime::CONSTANT_COPY)
Kevin Mayf29a2c52019-03-14 11:56:32 +0000128 {
129 location.offset = model.operandValues.size();
130 }
131
Sadik Armagane6e54a82019-05-08 10:18:05 +0100132 V1_0::Operand op = {};
133 op.type = operandType;
134 op.dimensions = dimensions;
135 op.lifetime = V1_0::OperandLifeTime::CONSTANT_COPY;
136 op.location = location;
surmeh0149b9e102018-05-17 14:11:25 +0100137
138 model.operandValues.resize(model.operandValues.size() + location.length);
139 for (uint32_t i = 0; i < totalElements; i++)
140 {
141 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
142 }
143
Nikhil Raj77605822018-09-03 11:25:56 +0100144 AddOperand<HalModel>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100145}
146
Matteo Martincighc7434122018-11-14 12:27:04 +0000147template<typename HalModel, typename T>
148void AddTensorOperand(HalModel& model,
149 const hidl_vec<uint32_t>& dimensions,
150 const std::vector<T>& values,
Sadik Armagane6e54a82019-05-08 10:18:05 +0100151 V1_0::OperandType operandType = V1_0::OperandType::TENSOR_FLOAT32,
152 V1_0::OperandLifeTime operandLifeTime = V1_0::OperandLifeTime::CONSTANT_COPY)
Matteo Martincighc7434122018-11-14 12:27:04 +0000153{
Kevin Mayf29a2c52019-03-14 11:56:32 +0000154 AddTensorOperand<HalModel, T>(model, dimensions, values.data(), operandType, operandLifeTime);
Matteo Martincighc7434122018-11-14 12:27:04 +0000155}
156
Nikhil Raj77605822018-09-03 11:25:56 +0100157template<typename HalModel>
158void AddInputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000159 const hidl_vec<uint32_t>& dimensions,
Sadik Armagane6e54a82019-05-08 10:18:05 +0100160 V1_0::OperandType operandType = V1_0::OperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100161{
Sadik Armagane6e54a82019-05-08 10:18:05 +0100162 V1_0::Operand op = {};
163 op.type = operandType;
164 op.scale = operandType == V1_0::OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
165 op.dimensions = dimensions;
166 op.lifetime = V1_0::OperandLifeTime::MODEL_INPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100167
Nikhil Raj77605822018-09-03 11:25:56 +0100168 AddOperand<HalModel>(model, op);
169
170 model.inputIndexes.resize(model.inputIndexes.size() + 1);
171 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
172}
173
174template<typename HalModel>
175void AddOutputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000176 const hidl_vec<uint32_t>& dimensions,
Sadik Armagane6e54a82019-05-08 10:18:05 +0100177 V1_0::OperandType operandType = V1_0::OperandType::TENSOR_FLOAT32)
Nikhil Raj77605822018-09-03 11:25:56 +0100178{
Sadik Armagane6e54a82019-05-08 10:18:05 +0100179 V1_0::Operand op = {};
180 op.type = operandType;
181 op.scale = operandType == V1_0::OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
182 op.dimensions = dimensions;
183 op.lifetime = V1_0::OperandLifeTime::MODEL_OUTPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100184
Nikhil Raj77605822018-09-03 11:25:56 +0100185 AddOperand<HalModel>(model, op);
186
187 model.outputIndexes.resize(model.outputIndexes.size() + 1);
188 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
189}
surmeh0149b9e102018-05-17 14:11:25 +0100190
Sadik Armagane6e54a82019-05-08 10:18:05 +0100191android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
192 armnn_driver::ArmnnDriver& driver,
193 ErrorStatus& prepareStatus,
194 ErrorStatus expectedStatus = ErrorStatus::NONE);
Nikhil Raj77605822018-09-03 11:25:56 +0100195
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100196#ifdef ARMNN_ANDROID_NN_V1_1
Nikhil Raj77605822018-09-03 11:25:56 +0100197
Sadik Armagane6e54a82019-05-08 10:18:05 +0100198android::sp<V1_0::IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
Nikhil Raj77605822018-09-03 11:25:56 +0100199 armnn_driver::ArmnnDriver& driver,
200 ErrorStatus& prepareStatus,
201 ErrorStatus expectedStatus = ErrorStatus::NONE);
202
203#endif
204
205template<typename HalModel>
Sadik Armagane6e54a82019-05-08 10:18:05 +0100206android::sp<V1_0::IPreparedModel> PrepareModel(const HalModel& model,
Nikhil Raj77605822018-09-03 11:25:56 +0100207 armnn_driver::ArmnnDriver& driver)
208{
209 ErrorStatus prepareStatus = ErrorStatus::NONE;
210 return PrepareModelWithStatus(model, driver, prepareStatus);
211}
surmeh0149b9e102018-05-17 14:11:25 +0100212
Sadik Armagane6e54a82019-05-08 10:18:05 +0100213ErrorStatus Execute(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100214 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100215 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100216
Sadik Armagane6e54a82019-05-08 10:18:05 +0100217android::sp<ExecutionCallback> ExecuteNoWait(android::sp<V1_0::IPreparedModel> preparedModel,
surmeh0149b9e102018-05-17 14:11:25 +0100218 const Request& request);
219
220} // namespace driverTestHelpers