blob: ccb6b983158a362f5f192ea985cb6d96582213a0 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once

// Default log tag for messages emitted from this test code.
// A LOG_TAG defined before this header is included (e.g. by the build) wins.
// NOTE(review): presumably consumed by the Android logging macros — confirm.
#if !defined(LOG_TAG)
#define LOG_TAG "ArmnnDriverTests"
#endif // !defined(LOG_TAG)
10
#include "../ArmnnDriver.hpp"

#include <condition_variable>
#include <cstring>
#include <iosfwd>
#include <mutex>
#include <utility>
13
14namespace android
15{
16namespace hardware
17{
18namespace neuralnetworks
19{
20namespace V1_0
21{
22
23std::ostream& operator<<(std::ostream& os, ErrorStatus stat);
24
25} // namespace android::hardware::neuralnetworks::V1_0
26} // namespace android::hardware::neuralnetworks
27} // namespace android::hardware
28} // namespace android
29
30namespace driverTestHelpers
31{
32
33std::ostream& operator<<(std::ostream& os, android::hardware::neuralnetworks::V1_0::ErrorStatus stat);
34
35struct ExecutionCallback : public IExecutionCallback
36{
37 ExecutionCallback() : mNotified(false) {}
38 Return<void> notify(ErrorStatus status) override;
39 /// wait until the callback has notified us that it is done
40 Return<void> wait();
41
42private:
43 // use a mutex and a condition variable to wait for asynchronous callbacks
44 std::mutex mMutex;
45 std::condition_variable mCondition;
46 // and a flag, in case we are notified before the wait call
47 bool mNotified;
48};
49
50class PreparedModelCallback : public IPreparedModelCallback
51{
52public:
53 PreparedModelCallback()
54 : m_ErrorStatus(ErrorStatus::NONE)
55 , m_PreparedModel()
56 { }
57 ~PreparedModelCallback() override { }
58
59 Return<void> notify(ErrorStatus status,
60 const android::sp<IPreparedModel>& preparedModel) override;
61 ErrorStatus GetErrorStatus() { return m_ErrorStatus; }
62 android::sp<IPreparedModel> GetPreparedModel() { return m_PreparedModel; }
63
64private:
65 ErrorStatus m_ErrorStatus;
66 android::sp<IPreparedModel> m_PreparedModel;
67};
68
69hidl_memory allocateSharedMemory(int64_t size);
70
71android::sp<IMemory> AddPoolAndGetData(uint32_t size, Request& request);
72
73void AddPoolAndSetData(uint32_t size, Request& request, const float* data);
74
telsoa01ce3e84a2018-08-31 09:31:35 +010075void AddOperand(::android::hardware::neuralnetworks::V1_0::Model& model, const Operand& op);
surmeh0149b9e102018-05-17 14:11:25 +010076
telsoa01ce3e84a2018-08-31 09:31:35 +010077void AddIntOperand(::android::hardware::neuralnetworks::V1_0::Model& model, int32_t value);
surmeh0149b9e102018-05-17 14:11:25 +010078
79template<typename T>
80OperandType TypeToOperandType();
81
82template<>
83OperandType TypeToOperandType<float>();
84
85template<>
86OperandType TypeToOperandType<int32_t>();
87
88template<typename T>
telsoa01ce3e84a2018-08-31 09:31:35 +010089void AddTensorOperand(::android::hardware::neuralnetworks::V1_0::Model& model,
90 hidl_vec<uint32_t> dimensions,
91 T* values,
92 OperandType operandType = OperandType::TENSOR_FLOAT32)
surmeh0149b9e102018-05-17 14:11:25 +010093{
94 uint32_t totalElements = 1;
95 for (uint32_t dim : dimensions)
96 {
97 totalElements *= dim;
98 }
99
100 DataLocation location = {};
101 location.offset = model.operandValues.size();
102 location.length = totalElements * sizeof(T);
103
104 Operand op = {};
telsoa01ce3e84a2018-08-31 09:31:35 +0100105 op.type = operandType;
surmeh0149b9e102018-05-17 14:11:25 +0100106 op.dimensions = dimensions;
107 op.lifetime = OperandLifeTime::CONSTANT_COPY;
108 op.location = location;
109
110 model.operandValues.resize(model.operandValues.size() + location.length);
111 for (uint32_t i = 0; i < totalElements; i++)
112 {
113 *(reinterpret_cast<T*>(&model.operandValues[location.offset]) + i) = values[i];
114 }
115
116 AddOperand(model, op);
117}
118
telsoa01ce3e84a2018-08-31 09:31:35 +0100119void AddInputOperand(::android::hardware::neuralnetworks::V1_0::Model& model,
120 hidl_vec<uint32_t> dimensions,
121 ::android::hardware::neuralnetworks::V1_0::OperandType operandType = OperandType::TENSOR_FLOAT32);
surmeh0149b9e102018-05-17 14:11:25 +0100122
telsoa01ce3e84a2018-08-31 09:31:35 +0100123void AddOutputOperand(::android::hardware::neuralnetworks::V1_0::Model& model,
124 hidl_vec<uint32_t> dimensions,
125 ::android::hardware::neuralnetworks::V1_0::OperandType operandType = OperandType::TENSOR_FLOAT32);
surmeh0149b9e102018-05-17 14:11:25 +0100126
telsoa01ce3e84a2018-08-31 09:31:35 +0100127android::sp<IPreparedModel> PrepareModel(const ::android::hardware::neuralnetworks::V1_0::Model& model,
surmeh0149b9e102018-05-17 14:11:25 +0100128 armnn_driver::ArmnnDriver& driver);
129
telsoa01ce3e84a2018-08-31 09:31:35 +0100130android::sp<IPreparedModel> PrepareModelWithStatus(const ::android::hardware::neuralnetworks::V1_0::Model& model,
surmeh0149b9e102018-05-17 14:11:25 +0100131 armnn_driver::ArmnnDriver& driver,
132 ErrorStatus & prepareStatus,
133 ErrorStatus expectedStatus=ErrorStatus::NONE);
134
135ErrorStatus Execute(android::sp<IPreparedModel> preparedModel,
136 const Request& request,
137 ErrorStatus expectedStatus=ErrorStatus::NONE);
138
139android::sp<ExecutionCallback> ExecuteNoWait(android::sp<IPreparedModel> preparedModel,
140 const Request& request);
141
142} // namespace driverTestHelpers