//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#pragma once

#include "ArmnnDriver.hpp"
#include "ArmnnDriverImpl.hpp"

#include <NeuralNetworks.h>
#include <ActivationFunctor.h>

#include <armnn/ArmNN.hpp>
#include <armnn/INetwork.hpp>
#include <CpuExecutor.h>

#include "Utils.hpp"

#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <vector>

namespace armnn_driver
{

// Forward declarations. In this header both types appear only as return types of
// member function declarations, so their complete definitions are not required here.
class ConstTensorPin;
class LayerInputHandle;

// Overall outcome of a model conversion, as reported by
// ModelToINetworkConverter::GetConversionResult().
// NOTE(review): the per-enumerator descriptions below are inferred from the names only;
// confirm them against the converter implementation.
enum class ConversionResult
{
    Success,            // presumably: the conversion pass completed
    ErrorMappingPools,  // presumably: the model's memory pools could not be mapped
    UnsupportedFeature  // presumably: the model uses something the converter cannot handle
};

37// A helper performing the conversion from an AndroidNN driver Model representation,
38// to an armnn::INetwork object
arovir01a15dc112018-09-03 17:12:56 +010039template<typename HalVersion>
telsoa015307bc12018-03-09 13:51:08 +000040class ModelToINetworkConverter
41{
42public:
kevmay01bc5f7842018-08-30 12:34:39 +010043 using HalModel = typename HalVersion::Model;
44
telsoa01ce3e84a2018-08-31 09:31:35 +010045 ModelToINetworkConverter(armnn::Compute compute,
Matteo Martincighe48bdff2018-09-03 13:50:50 +010046 const HalModel& model,
47 const std::set<unsigned int>& forcedUnsupportedOperations);
telsoa015307bc12018-03-09 13:51:08 +000048
49 ConversionResult GetConversionResult() const { return m_ConversionResult; }
50
51 // Returns the ArmNN INetwork corresponding to the input model, if preparation went smoothly, nullptr otherwise.
52 armnn::INetwork* GetINetwork() const { return m_Network.get(); }
53
54 bool IsOperationSupported(uint32_t operationIndex) const;
55
56private:
57 void Convert();
58
kevmay01bc5f7842018-08-30 12:34:39 +010059#if defined(ARMNN_ANDROID_NN_V1_1)
60 bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_1::Operation& operation);
arovir01a15dc112018-09-03 17:12:56 +010061
62 bool ConvertDiv(const ::android::hardware::neuralnetworks::V1_1::Operation& operation);
kevmay01bc5f7842018-08-30 12:34:39 +010063#endif
64
telsoa01ce3e84a2018-08-31 09:31:35 +010065 bool ConvertOperation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000066
telsoa01ce3e84a2018-08-31 09:31:35 +010067 bool ConvertAdd(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000068
telsoa01ce3e84a2018-08-31 09:31:35 +010069 bool ConvertAveragePool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000070
telsoa01ce3e84a2018-08-31 09:31:35 +010071 bool ConvertConcatenation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000072
telsoa01ce3e84a2018-08-31 09:31:35 +010073 bool ConvertConv2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000074
telsoa01ce3e84a2018-08-31 09:31:35 +010075 bool ConvertDepthwiseConv2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000076
telsoa01ce3e84a2018-08-31 09:31:35 +010077 bool ConvertFloor(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000078
telsoa01ce3e84a2018-08-31 09:31:35 +010079 bool ConvertFullyConnected(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000080
telsoa01ce3e84a2018-08-31 09:31:35 +010081 bool ConvertLogistic(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000082
telsoa01ce3e84a2018-08-31 09:31:35 +010083 bool ConvertLocalResponseNormalization(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000084
telsoa01ce3e84a2018-08-31 09:31:35 +010085 bool ConvertL2Normalization(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000086
telsoa01ce3e84a2018-08-31 09:31:35 +010087 bool ConvertL2Pool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000088
telsoa01ce3e84a2018-08-31 09:31:35 +010089 bool ConvertMaxPool2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000090
telsoa01ce3e84a2018-08-31 09:31:35 +010091 bool ConvertMul(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000092
telsoa01ce3e84a2018-08-31 09:31:35 +010093 bool ConvertReLu(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000094
telsoa01ce3e84a2018-08-31 09:31:35 +010095 bool ConvertReLu1(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000096
telsoa01ce3e84a2018-08-31 09:31:35 +010097 bool ConvertReLu6(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +000098
telsoa01ce3e84a2018-08-31 09:31:35 +010099 bool ConvertSoftmax(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000100
telsoa01ce3e84a2018-08-31 09:31:35 +0100101 bool ConvertTanH(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000102
telsoa01ce3e84a2018-08-31 09:31:35 +0100103 bool ConvertReshape(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000104
telsoa01ce3e84a2018-08-31 09:31:35 +0100105 bool ConvertResizeBilinear(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
telsoa015307bc12018-03-09 13:51:08 +0000106
telsoa01ce3e84a2018-08-31 09:31:35 +0100107 bool ConvertLstm(const ::android::hardware::neuralnetworks::V1_0::Operation& operation);
108
109 bool ConvertToActivation(const ::android::hardware::neuralnetworks::V1_0::Operation& operation,
110 const char* operationName,
arovir01a15dc112018-09-03 17:12:56 +0100111 const armnn::ActivationDescriptor& activationDesc);
telsoa015307bc12018-03-09 13:51:08 +0000112
telsoa01ce3e84a2018-08-31 09:31:35 +0100113 bool ConvertPooling2d(const ::android::hardware::neuralnetworks::V1_0::Operation& operation,
114 const char* name, armnn::PoolingAlgorithm poolType);
telsoa015307bc12018-03-09 13:51:08 +0000115
telsoa015307bc12018-03-09 13:51:08 +0000116 const void* GetOperandValueReadOnlyAddress(const Operand& operand) const;
117
arovir01a15dc112018-09-03 17:12:56 +0100118 template<typename HalOperation>
119 const Operand* GetInputOperand(const HalOperation& operation, uint32_t inputIndex) const;
telsoa015307bc12018-03-09 13:51:08 +0000120
arovir01a15dc112018-09-03 17:12:56 +0100121 template<typename HalOperation>
122 const Operand* GetOutputOperand(const HalOperation& operation, uint32_t outputIndex) const;
telsoa015307bc12018-03-09 13:51:08 +0000123
arovir01a15dc112018-09-03 17:12:56 +0100124 template<typename HalOperation, typename T>
125 bool GetInputScalar(const HalOperation& operation, uint32_t inputIndex, OperandType type, T& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000126
arovir01a15dc112018-09-03 17:12:56 +0100127 template<typename HalOperation>
128 bool GetInputInt32(const HalOperation& operation, uint32_t inputIndex, int32_t& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000129
arovir01a15dc112018-09-03 17:12:56 +0100130 template<typename HalOperation>
131 bool GetInputFloat32(const HalOperation& operation, uint32_t inputIndex, float& outValue) const;
telsoa015307bc12018-03-09 13:51:08 +0000132
arovir01a15dc112018-09-03 17:12:56 +0100133 template<typename HalOperation>
134 bool GetInputActivationFunctionImpl(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100135 uint32_t inputIndex,
136 OperandType type,
137 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000138
arovir01a15dc112018-09-03 17:12:56 +0100139 template<typename HalOperation>
140 bool GetInputActivationFunction(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100141 uint32_t inputIndex,
142 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000143
arovir01a15dc112018-09-03 17:12:56 +0100144 template<typename HalOperation>
145 bool GetInputActivationFunctionFromTensor(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100146 uint32_t inputIndex,
147 ActivationFn& outActivationFunction) const;
telsoa015307bc12018-03-09 13:51:08 +0000148
arovir01a15dc112018-09-03 17:12:56 +0100149 template<typename HalOperation>
150 bool GetOptionalInputActivation(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100151 uint32_t inputIndex,
152 ActivationFn& activationFunction) const;
153
arovir01a15dc112018-09-03 17:12:56 +0100154 template<typename HalOperation>
155 bool GetInputPaddingScheme(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100156 uint32_t inputIndex,
157 android::nn::PaddingScheme& outPaddingScheme) const;
158
arovir01a15dc112018-09-03 17:12:56 +0100159 template<typename HalOperation>
160 LayerInputHandle ConvertToLayerInputHandle(const HalOperation& operation, uint32_t inputIndex);
telsoa01ce3e84a2018-08-31 09:31:35 +0100161
arovir01a15dc112018-09-03 17:12:56 +0100162 template<typename HalOperation>
telsoa01ce3e84a2018-08-31 09:31:35 +0100163 ConstTensorPin ConvertOperationInputToConstTensorPin(
arovir01a15dc112018-09-03 17:12:56 +0100164 const HalOperation& operation,
165 uint32_t inputIndex,
telsoa015307bc12018-03-09 13:51:08 +0000166 const armnn::PermutationVector& dimensionMappings = g_DontPermute,
arovir01a15dc112018-09-03 17:12:56 +0100167 const armnn::TensorShape* overrideTensorShape = nullptr,
168 bool optional = false);
169
170 ConstTensorPin ConvertOperandToConstTensorPin(
171 const Operand& operand,
172 const armnn::PermutationVector& dimensionMappings = g_DontPermute,
173 const armnn::TensorShape* overrideTensorShape = nullptr,
174 bool optional = false);
telsoa015307bc12018-03-09 13:51:08 +0000175
176 bool GetTensorInt32Values(const Operand& operand, std::vector<int32_t>& outValues) const;
177
arovir01a15dc112018-09-03 17:12:56 +0100178 armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
179 ActivationFn activation,
telsoa015307bc12018-03-09 13:51:08 +0000180 armnn::IConnectableLayer* prevLayer);
181
arovir01a15dc112018-09-03 17:12:56 +0100182 template<typename HalOperation>
183 bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100184 uint32_t operationOutputIndex,
185 armnn::IConnectableLayer& layer,
186 uint32_t layerOutputIndex);
telsoa015307bc12018-03-09 13:51:08 +0000187
arovir01a15dc112018-09-03 17:12:56 +0100188 template<typename HalOperation>
189 bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
telsoa01ce3e84a2018-08-31 09:31:35 +0100190 uint32_t outputIndex,
telsoa015307bc12018-03-09 13:51:08 +0000191 armnn::IConnectableLayer& layer);
192
telsoa015307bc12018-03-09 13:51:08 +0000193 // Input data
kevmay01bc5f7842018-08-30 12:34:39 +0100194 armnn::Compute m_Compute;
195 const HalModel& m_Model;
196 const std::set<unsigned int>& m_ForcedUnsupportedOperations;
telsoa015307bc12018-03-09 13:51:08 +0000197
198 // Output data
telsoa01ce3e84a2018-08-31 09:31:35 +0100199 armnn::INetworkPtr m_Network;
200 ConversionResult m_ConversionResult;
201 std::map<uint32_t, bool> m_OperationSupported;
telsoa015307bc12018-03-09 13:51:08 +0000202
203 // Working/intermediate data
arovir01a15dc112018-09-03 17:12:56 +0100204 std::vector<armnn::IOutputSlot*> m_OutputSlotForOperand;
telsoa015307bc12018-03-09 13:51:08 +0000205 std::vector<android::nn::RunTimePoolInfo> m_MemPools;
206};
207
} // armnn_driver