//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#ifndef LOG_TAG
#define LOG_TAG "ArmnnDriverTests"
#endif // LOG_TAG

#include "../ArmnnDriver.hpp"

#include <condition_variable>
#include <cstring>
#include <iosfwd>
#include <mutex>
#include <type_traits>

#include <boost/test/unit_test.hpp>
surmeh0149b9e102018-05-17 14:11:25 +010014
15namespace android
16{
17namespace hardware
18{
19namespace neuralnetworks
20{
21namespace V1_0
22{
23
24std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
25
26} // namespace android::hardware::neuralnetworks::V1_0
27} // namespace android::hardware::neuralnetworks
28} // namespace android::hardware
29} // namespace android
30
31namespace driverTestHelpers
32{
33
Matteo Martincigh8b287c22018-09-07 09:25:10 +010034std::ostream& operator<<(std::ostream& os, V1_0::ErrorStatus stat);
surmeh0149b9e102018-05-17 14:11:25 +010035
36struct ExecutionCallback : public IExecutionCallback
37{
38 ExecutionCallback() : mNotified(false) {}
39 Return<void> notify(ErrorStatus status) override;
40 /// wait until the callback has notified us that it is done
41 Return<void> wait();
42
43private:
44 // use a mutex and a condition variable to wait for asynchronous callbacks
45 std::mutex mMutex;
46 std::condition_variable mCondition;
47 // and a flag, in case we are notified before the wait call
48 bool mNotified;
49};
50
51class PreparedModelCallback : public IPreparedModelCallback
52{
53public:
54 PreparedModelCallback()
55 : m_ErrorStatus(ErrorStatus::NONE)
56 , m_PreparedModel()
57 { }
58 ~PreparedModelCallback() override { }
59
60 Return<void> notify(ErrorStatus status,
61 const android::sp<IPreparedModel>& preparedModel) override;
62 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
63 android::sp<IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
64
65private:
66 ErrorStatus m_ErrorStatus;
67 android::sp<IPreparedModel> m_PreparedModel;
68};
69
70hidl_memory allocateSharedMemory(int64_t size);
71
72android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
73
74void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
75
Nikhil Raj77605822018-09-03 11:25:56 +010076template<typename HalModel>
77void AddOperand(HalModel& model, const Operand& op)
78{
79 model.operands.resize(model.operands.size() + 1);
80 model.operands[model.operands.size() - 1] = op;
81}
surmeh0149b9e102018-05-17 14:11:25 +010082
Nikhil Raj77605822018-09-03 11:25:56 +010083template<typename HalModel>
84void AddIntOperand(HalModel& model, int32_t value)
85{
86 DataLocation location = {};
87 location.offset = model.operandValues.size();
88 location.length = sizeof(int32_t);
89
90 Operand op = {};
91 op.type = OperandType::INT32;
92 op.dimensions = hidl_vec<uint32_t>{};
93 op.lifetime = OperandLifeTime::CONSTANT_COPY;
94 op.location = location;
95
96 model.operandValues.resize(model.operandValues.size() + location.length);
97 *reinterpret_cast<int32_t*>(&model.operandValues[location.offset]) = value;
98
99 AddOperand<HalModel>(model, op);
100}
surmeh0149b9e102018-05-17 14:11:25 +0100101
102template<typename T>
103OperandType TypeToOperandType();
104
105template<>
106OperandType TypeToOperandType<float>();
107
108template<>
109OperandType TypeToOperandType<int32_t>();
110
Nikhil Raj77605822018-09-03 11:25:56 +0100111template<typename HalModel, typename T>
112void AddTensorOperand(HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100113 hidl_vec<uint32_t> dimensions,
114 T* values,
115 OperandType operandType = OperandType::TENSOR_FLOAT32)
surmeh0149b9e102018-05-17 14:11:25 +0100116{
117 uint32_t totalElements = 1;
118 for (uint32_t dim : dimensions)
119 {
120 totalElements *= dim;
121 }
122
123 DataLocation location = {};
124 location.offset = model.operandValues.size();
125 location.length = totalElements * sizeof(T);
126
127 Operand op = {};
telsoa01ce3e84a2018-08-31 09:31:35 +0100128 op.type = operandType;
surmeh0149b9e102018-05-17 14:11:25 +0100129 op.dimensions = dimensions;
130 op.lifetime = OperandLifeTime::CONSTANT_COPY;
131 op.location = location;
132
133 model.operandValues.resize(model.operandValues.size() + location.length);
134 for (uint32_t i = 0; i < totalElements; i++)
135 {
136 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
137 }
138
Nikhil Raj77605822018-09-03 11:25:56 +0100139 AddOperand<HalModel>(model, op);
surmeh0149b9e102018-05-17 14:11:25 +0100140}
141
Nikhil Raj77605822018-09-03 11:25:56 +0100142template<typename HalModel>
143void AddInputOperand(HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100144 hidl_vec<uint32_t> dimensions,
Nikhil Raj77605822018-09-03 11:25:56 +0100145 OperandType operandType = OperandType::TENSOR_FLOAT32)
146{
147 Operand op = {};
148 op.type = operandType;
Matteo Martincigh8d50f8f2018-10-25 15:39:33 +0100149 op.scale = operandType == OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
Nikhil Raj77605822018-09-03 11:25:56 +0100150 op.dimensions = dimensions;
151 op.lifetime = OperandLifeTime::MODEL_INPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100152
Nikhil Raj77605822018-09-03 11:25:56 +0100153 AddOperand<HalModel>(model, op);
154
155 model.inputIndexes.resize(model.inputIndexes.size() + 1);
156 model.inputIndexes[model.inputIndexes.size() - 1] = model.operands.size() - 1;
157}
158
159template<typename HalModel>
160void AddOutputOperand(HalModel& model,
telsoa01ce3e84a2018-08-31 09:31:35 +0100161 hidl_vec<uint32_t> dimensions,
Nikhil Raj77605822018-09-03 11:25:56 +0100162 OperandType operandType = OperandType::TENSOR_FLOAT32)
163{
164 Operand op = {};
165 op.type = operandType;
166 op.scale = operandType == OperandType::TENSOR_QUANT8_ASYMM ? 1.f / 255.f : 0.f;
167 op.dimensions = dimensions;
168 op.lifetime = OperandLifeTime::MODEL_OUTPUT;
surmeh0149b9e102018-05-17 14:11:25 +0100169
Nikhil Raj77605822018-09-03 11:25:56 +0100170 AddOperand<HalModel>(model, op);
171
172 model.outputIndexes.resize(model.outputIndexes.size() + 1);
173 model.outputIndexes[model.outputIndexes.size() - 1] = model.operands.size() - 1;
174}
surmeh0149b9e102018-05-17 14:11:25 +0100175
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100176android::sp<IPreparedModel> PrepareModelWithStatus(const V1_0::Model& model,
surmeh0149b9e102018-05-17 14:11:25 +0100177 armnn_driver::ArmnnDriver& driver,
Nikhil Raj77605822018-09-03 11:25:56 +0100178 ErrorStatus& prepareStatus,
179 ErrorStatus expectedStatus = ErrorStatus::NONE);
180
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100181#ifdef ARMNN_ANDROID_NN_V1_1
Nikhil Raj77605822018-09-03 11:25:56 +0100182
Matteo Martincigh8b287c22018-09-07 09:25:10 +0100183android::sp<IPreparedModel> PrepareModelWithStatus(const V1_1::Model& model,
Nikhil Raj77605822018-09-03 11:25:56 +0100184 armnn_driver::ArmnnDriver& driver,
185 ErrorStatus& prepareStatus,
186 ErrorStatus expectedStatus = ErrorStatus::NONE);
187
188#endif
189
190template<typename HalModel>
191android::sp<IPreparedModel> PrepareModel(const HalModel& model,
192 armnn_driver::ArmnnDriver& driver)
193{
194 ErrorStatus prepareStatus = ErrorStatus::NONE;
195 return PrepareModelWithStatus(model, driver, prepareStatus);
196}
surmeh0149b9e102018-05-17 14:11:25 +0100197
198ErrorStatus Execute(android::sp<IPreparedModel> preparedModel,
199 const Request& request,
Nikhil Raj77605822018-09-03 11:25:56 +0100200 ErrorStatus expectedStatus = ErrorStatus::NONE);
surmeh0149b9e102018-05-17 14:11:25 +0100201
202android::sp<ExecutionCallback> ExecuteNoWait(android::sp<IPreparedModel> preparedModel,
203 const Request& request);
204
205} // namespace driverTestHelpers