blob: 12a085d6caf109515ead625c419f5cc56681c306 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
7#include "armnn/INetwork.hpp"
8#include "armnnTfLiteParser/ITfLiteParser.hpp"
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +01009#include "armnn/Types.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010010
11#include <schema_generated.h>
12#include <functional>
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010013#include <unordered_map>
telsoa01c577f2c2018-08-31 09:22:23 +010014#include <vector>
15
16namespace armnnTfLiteParser
17{
18
Kevin May7d96b162021-02-03 17:38:41 +000019class TfLiteParserImpl
telsoa01c577f2c2018-08-31 09:22:23 +010020{
21public:
22 // Shorthands for TfLite types
23 using ModelPtr = std::unique_ptr<tflite::ModelT>;
Derek Lambertiff05cc52019-04-26 13:05:17 +010024 using SubgraphPtr = std::unique_ptr<tflite::SubGraphT>;
telsoa01c577f2c2018-08-31 09:22:23 +010025 using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
26 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
27 using TensorPtr = std::unique_ptr<tflite::TensorT>;
28 using TensorRawPtr = const tflite::TensorT *;
29 using TensorRawPtrVector = std::vector<TensorRawPtr>;
30 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
31 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
32 using BufferPtr = std::unique_ptr<tflite::BufferT>;
33 using BufferRawPtr = const tflite::BufferT *;
34
35public:
36 /// Create the network from a flatbuffers binary file on disk
Kevin May7d96b162021-02-03 17:38:41 +000037 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010038
39 /// Create the network from a flatbuffers binary
Kevin May7d96b162021-02-03 17:38:41 +000040 armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent);
telsoa01c577f2c2018-08-31 09:22:23 +010041
42
43 /// Retrieve binding info (layer id and tensor info) for the network input identified by
44 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000045 BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
46 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010047
48 /// Retrieve binding info (layer id and tensor info) for the network output identified by
49 /// the given layer name and subgraph id
Kevin May7d96b162021-02-03 17:38:41 +000050 BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
51 const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010052
53 /// Return the number of subgraphs in the parsed model
Kevin May7d96b162021-02-03 17:38:41 +000054 size_t GetSubgraphCount() const;
telsoa01c577f2c2018-08-31 09:22:23 +010055
56 /// Return the input tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000057 std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010058
59 /// Return the output tensor names for a given subgraph
Kevin May7d96b162021-02-03 17:38:41 +000060 std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const;
telsoa01c577f2c2018-08-31 09:22:23 +010061
Kevin May7d96b162021-02-03 17:38:41 +000062 TfLiteParserImpl(const armnn::Optional<ITfLiteParser::TfLiteParserOptions>& options = armnn::EmptyOptional());
63 ~TfLiteParserImpl() = default;
telsoa01c577f2c2018-08-31 09:22:23 +010064
65public:
66 // testable helpers
67 static ModelPtr LoadModelFromFile(const char * fileName);
68 static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
69 static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
70 static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
71 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
72 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
73 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
74 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
75
76 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
77 static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
78 const armnn::TensorInfo & inputTensorInfo);
Sadikb94967b2018-09-19 15:30:00 +010079 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
80 const std::vector<int32_t> & targetDimsIn);
telsoa01c577f2c2018-08-31 09:22:23 +010081
82private:
83 // No copying allowed until it is wanted and properly implemented
Kevin May7d96b162021-02-03 17:38:41 +000084 TfLiteParserImpl(const TfLiteParserImpl &) = delete;
85 TfLiteParserImpl & operator=(const TfLiteParserImpl &) = delete;
telsoa01c577f2c2018-08-31 09:22:23 +010086
87 /// Create the network from an already loaded flatbuffers model
88 armnn::INetworkPtr CreateNetworkFromModel();
89
90 // signature for the parser functions
Kevin May7d96b162021-02-03 17:38:41 +000091 using OperatorParsingFunction = void(TfLiteParserImpl::*)(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010092
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010093 void ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010094 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
Aron Virginas-Tarc975f922019-10-23 17:38:17 +010095
Finn Williamsc42c3842019-01-22 14:18:11 +000096 void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
Nina Drozd200e3802019-04-15 09:47:39 +010097 void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010098 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesdb947e22019-02-08 18:52:21 -020099 void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100100 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100101 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan26868492021-01-22 14:25:31 +0000102 void ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100103 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsed66d142019-12-06 09:55:55 +0000104 void ParseDequantize(size_t subgraphIndex, size_t operatorIndex);
keidav011b3e2ea2019-02-21 10:07:37 +0000105 void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
Matthew Sloyan7515d072020-12-16 12:50:01 +0000106 void ParseElu(size_t subgraphIndex, size_t operatorIndex);
Derek Lambertif0176992020-04-28 13:37:49 +0100107 void ParseExp(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100108 void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan26868492021-01-22 14:25:31 +0000109 void ParseGather(size_t subgraphIndex, size_t operatorIndex);
Jan Eilers2f746b32020-07-28 14:00:06 +0100110 void ParseHardSwish(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan12239e72020-05-27 11:06:17 +0100111 void ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsc42c3842019-01-22 14:18:11 +0000112 void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
Matthew Jackson28c94572019-07-18 10:47:03 +0100113 void ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex);
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100114 void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesb8d805e2019-02-12 22:57:13 -0200115 void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100116 void ParseMean(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves8f6d7a72019-02-12 22:58:18 -0200117 void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100118 void ParseMul(size_t subgraphIndex, size_t operatorIndex);
Darshan Patel83fcf982020-05-26 22:22:42 +0530119 void ParseNeg(size_t subgraphIndex, size_t operatorIndex);
Matthew Jacksonbcca1f42019-07-16 11:39:21 +0100120 void ParsePack(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100121 void ParsePad(size_t subgraphIndex, size_t operatorIndex);
122 void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
Sadik Armagan66dedc72019-12-10 16:32:07 +0000123 void ParseQuantize(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan58f39192018-09-17 14:14:39 +0100124 void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
125 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
Sadikb94967b2018-09-19 15:30:00 +0100126 void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagana3b31f02019-12-05 09:08:53 +0000127 void ParseResize(size_t subgraphIndex, size_t operatorIndex, armnn::ResizeMethod resizeMethod);
Bruno Goncalves3f58ddb2019-02-07 18:40:11 -0200128 void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagana3b31f02019-12-05 09:08:53 +0000129 void ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex);
josh minorba424d22019-11-13 10:55:17 -0600130 void ParseSlice(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100131 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbaded142019-02-08 19:02:48 -0200132 void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100133 void ParseSplit(size_t subgraphIndex, size_t operatorIndex);
Derek Lambertif0176992020-04-28 13:37:49 +0100134 void ParseSplitV(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100135 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200136 void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbbeae262019-02-07 18:37:39 -0200137 void ParseSub(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan0c3ea5b2021-02-03 09:29:30 +0000138 void ParseSum(size_t subgraphIndex, size_t operatorIndex);
Darshan Patel42b3d7d2020-05-25 22:30:07 +0530139 void ParseDiv(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd99851762019-04-09 09:37:38 +0100140 void ParseTanH(size_t subgraphIndex, size_t operatorIndex);
Keith Davis4cd29a02019-09-09 14:49:20 +0100141 void ParseTranspose(size_t subgraphIndex, size_t operatorIndex);
Matthew Jackson74bf7da2019-08-16 16:51:42 +0100142 void ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex);
Nina Drozd200e3802019-04-15 09:47:39 +0100143 void ParseUnpack(size_t subgraphIndex, size_t operatorIndex);
Inki Daed4619e22020-09-10 15:33:54 +0900144 void ParseArgMax(size_t subgraphIndex, size_t operatorIndex);
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100145
telsoa01c577f2c2018-08-31 09:22:23 +0100146 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
147 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
148 void RegisterInputSlots(size_t subgraphIndex,
149 size_t operatorIndex,
150 armnn::IConnectableLayer* layer,
151 const std::vector<unsigned int>& tensorIndexes);
152 void RegisterOutputSlots(size_t subgraphIndex,
153 size_t operatorIndex,
154 armnn::IConnectableLayer* layer,
155 const std::vector<unsigned int>& tensorIndexes);
156
157 void SetupInputLayers(size_t subgraphIndex);
158 void SetupOutputLayers(size_t subgraphIndex);
Bruno Goncalves3d7efe92018-12-27 14:21:43 -0200159 void SetupConstantLayers(size_t subgraphIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100160
161 void ResetParser();
162
Bruno Goncalves9c761a62018-12-27 14:20:35 -0200163 void AddBroadcastReshapeLayer(size_t subgraphIndex,
164 size_t operatorIndex,
165 armnn::IConnectableLayer* layer);
166
telsoa01c577f2c2018-08-31 09:22:23 +0100167 /// Attach an activation layer to the one passed as a parameter
Sadik Armagan58f39192018-09-17 14:14:39 +0100168 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
169 unsigned int outputSlot,
170 tflite::ActivationFunctionType activationType);
telsoa01c577f2c2018-08-31 09:22:23 +0100171
172 // SupportedDataStorage's purpose is to hold data till we pass over to the network.
173 // We don't care about the content, and we want a single datatype to simplify the code.
174 struct SupportedDataStorage
175 {
Matteo Martincigh747ef822018-12-18 09:26:39 +0000176 public:
177 // Convenience constructors
178 SupportedDataStorage(std::unique_ptr<float[]>&& data);
179 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
Keith Davisd305e1a2020-01-22 11:57:54 +0000180 SupportedDataStorage(std::unique_ptr<int8_t[]>&& data);
Matteo Martincigh747ef822018-12-18 09:26:39 +0000181 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
telsoa01c577f2c2018-08-31 09:22:23 +0100182
Matteo Martincigh747ef822018-12-18 09:26:39 +0000183 private:
184 // Pointers to the data buffers
185 std::unique_ptr<float[]> m_FloatData;
186 std::unique_ptr<uint8_t[]> m_Uint8Data;
Keith Davisd305e1a2020-01-22 11:57:54 +0000187 std::unique_ptr<int8_t[]> m_Int8Data;
Matteo Martincigh747ef822018-12-18 09:26:39 +0000188 std::unique_ptr<int32_t[]> m_Int32Data;
telsoa01c577f2c2018-08-31 09:22:23 +0100189 };
190
Matteo Martincigh747ef822018-12-18 09:26:39 +0000191
192 template<typename T>
Kevin May7d96b162021-02-03 17:38:41 +0000193 std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
194 CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
195 TfLiteParserImpl::TensorRawPtr tensorPtr,
Matteo Martincigh747ef822018-12-18 09:26:39 +0000196 armnn::TensorInfo& tensorInfo,
197 armnn::Optional<armnn::PermutationVector&> permutationVector);
198
199 std::pair<armnn::ConstTensor, SupportedDataStorage>
200 CreateConstTensor(TensorRawPtr tensorPtr,
201 armnn::TensorInfo& tensorInfo,
202 armnn::Optional<armnn::PermutationVector&> permutationVector);
telsoa01c577f2c2018-08-31 09:22:23 +0100203
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100204 // Settings for configuring the TfLiteParser
205 armnn::Optional<ITfLiteParser::TfLiteParserOptions> m_Options;
206
telsoa01c577f2c2018-08-31 09:22:23 +0100207 /// The network we're building. Gets cleared after it is passed to the user
208 armnn::INetworkPtr m_Network;
telsoa01c577f2c2018-08-31 09:22:23 +0100209 ModelPtr m_Model;
210
Aron Virginas-Tarc975f922019-10-23 17:38:17 +0100211 std::vector<OperatorParsingFunction> m_ParserFunctions;
212 std::unordered_map<std::string, OperatorParsingFunction> m_CustomParserFunctions;
213
telsoa01c577f2c2018-08-31 09:22:23 +0100214 /// A mapping of an output slot to each of the input slots it should be connected to
215 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
216 /// The inputSlots are from the layers that use this tensor as one of their inputs
217 struct TensorSlots
218 {
219 armnn::IOutputSlot* outputSlot;
220 std::vector<armnn::IInputSlot*> inputSlots;
221
222 TensorSlots() : outputSlot(nullptr) { }
223 };
224 typedef std::vector<TensorSlots> TensorConnections;
225 /// Connections for tensors in each subgraph
226 /// The first index is the subgraph ID, the second index is the tensor ID
227 std::vector<TensorConnections> m_SubgraphConnections;
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000228
229 /// This is used in case that the model does not speciry the output.
230 /// The shape can be calculated from the options.
231 std::vector<std::vector<unsigned int>> m_OverridenOutputShapes;
telsoa01c577f2c2018-08-31 09:22:23 +0100232};
233
234}