//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#pragma once

#include "ArmnnDriver.hpp"

#include <NeuralNetworks.h>
#include <ActivationFunctor.h>

#include <armnn/ArmNN.hpp>
#include <armnn/INetwork.hpp>
#include <CpuExecutor.h>

#include "Utils.hpp"

#include <map>
#include <memory>
#include <set>
#include <vector>

23namespace armnn_driver
24{
25
26class ConstTensorPin;
27class LayerInputHandle;
28
29enum class ConversionResult
30{
31 Success,
32 ErrorMappingPools,
33 UnsupportedFeature
34};
35
kevmay01bc5f7842018-08-30 12:34:39 +010036struct HalVersion_1_0
37{
38 using Model = ::android::hardware::neuralnetworks::V1_0::Model;
39};
40
41#if defined(ARMNN_ANDROID_NN_V1_1)
42struct HalVersion_1_1
43{
44 using Model = ::android::hardware::neuralnetworks::V1_1::Model;
45};
46#endif
47
telsoa015307bc12018-03-09 13:51:08 +000048// A helper performing the conversion from an AndroidNN driver Model representation,
49// to an armnn::INetwork object
arovir01a15dc112018-09-03 17:12:56 +010050template<typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000051class ModelToINetworkConverter
52{
53public:
kevmay01bc5f7842018-08-30 12:34:39 +010054 using HalModel = typename HalVersion::Model;
55
telsoa01ce3e84a2018-08-31 09:31:35 +010056 ModelToINetworkConverter(armnn::Compute compute,
kevmay01bc5f7842018-08-30 12:34:39 +010057 const HalModel& model,
telsoa015307bc12018-03-09 13:51:08 +000058 const std::set<unsigned int>& forcedUnsupportedOperations);
59
60 ConversionResult GetConversionResult() const { return m_ConversionResult; }
61
62 // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise.
63 armnn::INetwork* GetINetwork() const { return m_Network.get(); }
64
65 bool IsOperationSupported(uint32_t operationIndex) const;
66
67private:
68 void Convert();
69
kevmay01bc5f7842018-08-30 12:34:39 +010070#if defined(ARMNN_ANDROID_NN_V1_1)
71 bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_1::Operation& operation);
arovir01a15dc112018-09-03 17:12:56 +010072
73 bool ConvertDiv(const ::android::hardware::neuralnetworks::V1_1::Operation& operation);
kevmay01bc5f7842018-08-30 12:34:39 +010074#endif
75
telsoa01ce3e84a2018-08-31 09:31:35 +010076 bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000077
telsoa01ce3e84a2018-08-31 09:31:35 +010078 bool ConvertAdd(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000079
telsoa01ce3e84a2018-08-31 09:31:35 +010080 bool ConvertAveragePool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000081
telsoa01ce3e84a2018-08-31 09:31:35 +010082 bool ConvertConcatenation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000083
telsoa01ce3e84a2018-08-31 09:31:35 +010084 bool ConvertConv2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000085
telsoa01ce3e84a2018-08-31 09:31:35 +010086 bool ConvertDepthwiseConv2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000087
telsoa01ce3e84a2018-08-31 09:31:35 +010088 bool ConvertFloor(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000089
telsoa01ce3e84a2018-08-31 09:31:35 +010090 bool ConvertFullyConnected(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000091
telsoa01ce3e84a2018-08-31 09:31:35 +010092 bool ConvertLogistic(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000093
telsoa01ce3e84a2018-08-31 09:31:35 +010094 bool ConvertLocalResponseNormalization(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000095
telsoa01ce3e84a2018-08-31 09:31:35 +010096 bool ConvertL2Normalization(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000097
telsoa01ce3e84a2018-08-31 09:31:35 +010098 bool ConvertL2Pool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000099
telsoa01ce3e84a2018-08-31 09:31:35 +0100100 bool ConvertMaxPool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000101
telsoa01ce3e84a2018-08-31 09:31:35 +0100102 bool ConvertMul(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000103
telsoa01ce3e84a2018-08-31 09:31:35 +0100104 bool ConvertReLu(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000105
telsoa01ce3e84a2018-08-31 09:31:35 +0100106 bool ConvertReLu1(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000107
telsoa01ce3e84a2018-08-31 09:31:35 +0100108 bool ConvertReLu6(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000109
telsoa01ce3e84a2018-08-31 09:31:35 +0100110 bool ConvertSoftmax(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000111
telsoa01ce3e84a2018-08-31 09:31:35 +0100112 bool ConvertTanH(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000113
telsoa01ce3e84a2018-08-31 09:31:35 +0100114 bool ConvertReshape(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000115
telsoa01ce3e84a2018-08-31 09:31:35 +0100116 bool ConvertResizeBilinear(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000117
telsoa01ce3e84a2018-08-31 09:31:35 +0100118 bool ConvertLstm(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
119
120 bool ConvertToActivation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation,
121 const char* operationName,
arovir01a15dc112018-09-03 17:12:56 +0100122 const armnn::ActivationDescriptor& activationDesc);
telsoa015307bc12018-03-09 13:51:08 +0000123
telsoa01ce3e84a2018-08-31 09:31:35 +0100124 bool ConvertPooling2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation,
125 const char* name, armnn::PoolingAlgorithm poolType);
telsoa015307bc12018-03-09 13:51:08 +0000126
telsoa015307bc12018-03-09 13:51:08 +0000127 const void* GetOperandValueReadOnlyAddress(const Operand& operand) const;
128
arovir01a15dc112018-09-03 17:12:56 +0100129 template<typename HalOperation>
130 const Operand* GetInputOperand(const HalOperation& operation, uint32_t inputIndex) const;
telsoa015307bc12018-03-09 13:51:08 +0000131
arovir01a15dc112018-09-03 17:12:56 +0100132 template<typename HalOperation>
133 const Operand* GetOutputOperand(const HalOperation& operation, uint32_t outputIndex) const;
telsoa015307bc12018-03-09 13:51:08 +0000134
arovir01a15dc112018-09-03 17:12:56 +0100135 template<typename HalOperation, typename T>
136 bool GetInputScalar(const HalOperation& operation, uint32_t inputIndex, OperandType type, T& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000137
arovir01a15dc112018-09-03 17:12:56 +0100138 template<typename HalOperation>
139 bool GetInputInt32(const HalOperation& operation, uint32_t inputIndex, int32_t& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000140
arovir01a15dc112018-09-03 17:12:56 +0100141 template<typename HalOperation>
142 bool GetInputFloat32(const HalOperation& operation, uint32_t inputIndex, float& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000143
arovir01a15dc112018-09-03 17:12:56 +0100144 template<typename HalOperation>
145 bool GetInputActivationFunctionImpl(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100146 uint32_t inputIndex,
147 OperandType type,
148 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000149
arovir01a15dc112018-09-03 17:12:56 +0100150 template<typename HalOperation>
151 bool GetInputActivationFunction(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100152 uint32_t inputIndex,
153 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000154
arovir01a15dc112018-09-03 17:12:56 +0100155 template<typename HalOperation>
156 bool GetInputActivationFunctionFromTensor(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100157 uint32_t inputIndex,
158 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000159
arovir01a15dc112018-09-03 17:12:56 +0100160 template<typename HalOperation>
161 bool GetOptionalInputActivation(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100162 uint32_t inputIndex,
163 ActivationFn& activationFunction) const;
164
arovir01a15dc112018-09-03 17:12:56 +0100165 template<typename HalOperation>
166 bool GetInputPaddingScheme(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100167 uint32_t inputIndex,
168 android::nn::PaddingScheme& outPaddingScheme) const;
169
arovir01a15dc112018-09-03 17:12:56 +0100170 template<typename HalOperation>
171 LayerInputHandle ConvertToLayerInputHandle(const HalOperation& operation, uint32_t inputIndex);
telsoa01ce3e84a2018-08-31 09:31:35 +0100172
arovir01a15dc112018-09-03 17:12:56 +0100173 template<typename HalOperation>
telsoa01ce3e84a2018-08-31 09:31:35 +0100174 ConstTensorPin ConvertOperationInputToConstTensorPin(
arovir01a15dc112018-09-03 17:12:56 +0100175 const HalOperation& operation,
176 uint32_t inputIndex,
telsoa015307bc12018-03-09 13:51:08 +0000177 const armnn::PermutationVector& dimensionMappings = g_DontPermute,
arovir01a15dc112018-09-03 17:12:56 +0100178 const armnn::TensorShape* overrideTensorShape = nullptr,
179 bool optional = false);
180
181 ConstTensorPin ConvertOperandToConstTensorPin(
182 const Operand& operand,
183 const armnn::PermutationVector& dimensionMappings = g_DontPermute,
184 const armnn::TensorShape* overrideTensorShape = nullptr,
185 bool optional = false);
telsoa015307bc12018-03-09 13:51:08 +0000186
187 bool GetTensorInt32Values(const Operand& operand, std::vector<int32_t>& outValues) const;
188
arovir01a15dc112018-09-03 17:12:56 +0100189 armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
190 ActivationFn activation,
telsoa015307bc12018-03-09 13:51:08 +0000191 armnn::IConnectableLayer* prevLayer);
192
arovir01a15dc112018-09-03 17:12:56 +0100193 template<typename HalOperation>
194 bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100195 uint32_t operationOutputIndex,
196 armnn::IConnectableLayer& layer,
197 uint32_t layerOutputIndex);
telsoa015307bc12018-03-09 13:51:08 +0000198
arovir01a15dc112018-09-03 17:12:56 +0100199 template<typename HalOperation>
200 bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100201 uint32_t outputIndex,
telsoa015307bc12018-03-09 13:51:08 +0000202 armnn::IConnectableLayer& layer);
203
telsoa015307bc12018-03-09 13:51:08 +0000204 // Input data
kevmay01bc5f7842018-08-30 12:34:39 +0100205 armnn::Compute m_Compute;
206 const HalModel& m_Model;
207 const std::set<unsigned int>& m_ForcedUnsupportedOperations;
telsoa015307bc12018-03-09 13:51:08 +0000208
209 // Output data
telsoa01ce3e84a2018-08-31 09:31:35 +0100210 armnn::INetworkPtr m_Network;
211 ConversionResult m_ConversionResult;
212 std::map<uint32_t, bool> m_OperationSupported;
telsoa015307bc12018-03-09 13:51:08 +0000213
214 // Working/intermediate data
arovir01a15dc112018-09-03 17:12:56 +0100215 std::vector<armnn::IOutputSlot*> m_OutputSlotForOperand;
telsoa015307bc12018-03-09 13:51:08 +0000216 std::vector<android::nn::RunTimePoolInfo> m_MemPools;
217};
218
arovir01a15dc112018-09-03 17:12:56 +0100219} // armnn_driver