blob: 4d91ae220712774cab9f270b86020a085496d7e3 [file] [log] [blame]
surmeh0149b9e102018-05-17 14:11:25 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beck93e48982018-09-05 13:05:09 +01003// SPDX-License-Identifier: MIT
surmeh0149b9e102018-05-17 14:11:25 +01004//
5#pragma once
6
7#ifndef LOG_TAG
8#define LOG_TAG "ArmnnDriverTests"
9#endif // LOG_TAG
10
#include "../ArmnnDriver.hpp"

#include <boost/test/unit_test.hpp>

#include <cstring>
#include <iosfwd>
#include <type_traits>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
36struct ExecutionCallback : public IExecutionCallback
37{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
51class PreparedModelCallback : public IPreparedModelCallback
52{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
61 const android::sp<IPreparedModel>& preparedModel) override;
62 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
63 android::sp<IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
64
65private:
66 ErrorStatus m_ErrorStatus;
67 android::sp<IPreparedModel> m_PreparedModel;
68};
69
70hidl_memory allocateSharedMemory(int64_t size);
71
72android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
73
74void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
75
Nikhil Raj77605822018-09-03 11:25:56 +010076template<typename HalModel>
77void AddOperand(HalModel& model, const Operand& op)
78{
79 model.operands.resize(model.operands.size() + 1);
80 model.operands[model.operands.size() - 1] = op;
81}
surmeh0149b9e102018-05-17 14:11:25 +010082
Nikhil Raj77605822018-09-03 11:25:56 +010083template<typename HalModel>
84void AddIntOperand(HalModel& model, int32_t value)
85{
86 DataLocation location = {};
87 location.offset = model.operandValues.size();
88 location.length = sizeof(int32_t);
89
90 Operand op = {};
91 op.type = OperandType::INT32;
92 op.dimensions = hidl_vec<uint32_t>{};
93 op.lifetime = OperandLifeTime::CONSTANT_COPY;
94 op.location = location;
95
96 model.operandValues.resize(model.operandValues.size() + location.length);
97 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
98
99 AddOperand<HalModel>(model, op);
100}
surmeh0149b9e102018-05-17 14:11:25 +0100101
102template<typename T>
103OperandType TypeToOperandType();
104
105template<>
106OperandType TypeToOperandType<float>();
107
108template<>
109OperandType TypeToOperandType<int32_t>();
110
Nikhil Raj77605822018-09-03 11:25:56 +0100111template<typename HalModel, typename T>
112void AddTensorOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000113 const hidl_vec<uint32_t>& dimensions,
114 const T* values,
telsoa01ce3e84a2018-08-31 09:31:35 +0100115 OperandType operandType = OperandType::TENSOR_FLOAT32)
surmeh0149b9e102018-05-17 14:11:25 +0100116{
117 uint32_t totalElements = 1;
118 for (uint32_t dim : dimensions)
119 {
120 totalElements *= dim;
121 }
122
123 DataLocation location = {};
124 location.offset = model.operandValues.size();
125 location.length = totalElements * sizeof(T);
126
127 Operand op = {};
telsoa01ce3e84a2018-08-31 09:31:35 +0100128 op.type = operandType;
surmeh0149b9e102018-05-17 14:11:25 +0100129 op.dimensions = dimensions;
130 op.lifetime = OperandLifeTime::CONSTANT_COPY;
131 op.location = location;
132
133 model.operandValues.resize(model.operandValues.size() + location.length);
134 for (uint32_t i = 0; i < totalElements; i++)
135 {
136 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
137 }
138
Nikhil Raj77605822018-09-03 11:25:56 +0100139 AddOperand<HalModel>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100140}
141
Matteo Martincighc7434122018-11-14 12:27:04 +0000142template<typename HalModel, typename T>
143void AddTensorOperand(HalModel& model,
144 const hidl_vec<uint32_t>& dimensions,
145 const std::vector<T>& values,
146 OperandType operandType = OperandType::TENSOR_FLOAT32)
147{
148 AddTensorOperand<HalModel, T>(model, dimensions, values.data(), operandType);
149}
150
Nikhil Raj77605822018-09-03 11:25:56 +0100151template<typename HalModel>
152void AddInputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000153 const hidl_vec<uint32_t>& dimensions,
Nikhil Raj77605822018-09-03 11:25:56 +0100154 OperandType operandType = OperandType::TENSOR_FLOAT32)
155{
156 Operand op = {};
157 op.type = operandType;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100158 op.scale = operandType == OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
Nikhil Raj77605822018-09-03 11:25:56 +0100159 op.dimensions = dimensions;
160 op.lifetime = OperandLifeTime::MODEL_INPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100161
Nikhil Raj77605822018-09-03 11:25:56 +0100162 AddOperand<HalModel>(model, op);
163
164 model.inputIndexes.resize(model.inputIndexes.size() + 1);
165 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
166}
167
168template<typename HalModel>
169void AddOutputOperand(HalModel& model,
Matteo Martincighc7434122018-11-14 12:27:04 +0000170 const hidl_vec<uint32_t>& dimensions,
Nikhil Raj77605822018-09-03 11:25:56 +0100171 OperandType operandType = OperandType::TENSOR_FLOAT32)
172{
173 Operand op = {};
174 op.type = operandType;
175 op.scale = operandType == OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
176 op.dimensions = dimensions;
177 op.lifetime = OperandLifeTime::MODEL_OUTPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100178
Nikhil Raj77605822018-09-03 11:25:56 +0100179 AddOperand<HalModel>(model, op);
180
181 model.outputIndexes.resize(model.outputIndexes.size() + 1);
182 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
183}
surmeh0149b9e102018-05-17 14:11:25 +0100184
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100185android::sp<IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
surmeh0149b9e102018-05-17 14:11:25 +0100186 armnn_driver::ArmnnDriver& driver,
Nikhil Raj77605822018-09-03 11:25:56 +0100187 ErrorStatus& prepareStatus,
188 ErrorStatus expectedStatus = ErrorStatus::NONE);
189
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100190#ifdef ARMNN_ANDROID_NN_V1_1
Nikhil Raj77605822018-09-03 11:25:56 +0100191
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100192android::sp<IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
Nikhil Raj77605822018-09-03 11:25:56 +0100193 armnn_driver::ArmnnDriver& driver,
194 ErrorStatus& prepareStatus,
195 ErrorStatus expectedStatus = ErrorStatus::NONE);
196
197#endif
198
199template<typename HalModel>
200android::sp<IPreparedModel> PrepareModel(const HalModel& model,
201 armnn_driver::ArmnnDriver& driver)
202{
203 ErrorStatus prepareStatus = ErrorStatus::NONE;
204 return PrepareModelWithStatus(model, driver, prepareStatus);
205}
surmeh0149b9e102018-05-17 14:11:25 +0100206
207ErrorStatus Execute(android::sp<IPreparedModel> preparedModel,
208 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100209 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100210
211android::sp<ExecutionCallback> ExecuteNoWait(android::sp<IPreparedModel> preparedModel,
212 const Request& request);
213
214} // namespace driverTestHelpers