blob: 2895487214dfaf3e7abb4cfecefb3c89823c6ef5 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
7#include "armnn/INetwork.hpp"
8#include "armnnTfLiteParser/ITfLiteParser.hpp"
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +01009#include "armnn/Types.hpp"
telsoa01c577f2c2018-08-31 09:22:23 +010010
11#include <schema_generated.h>
12#include <functional>
13#include <vector>
14
15namespace armnnTfLiteParser
16{
17
18class TfLiteParser : public ITfLiteParser
19{
20public:
21 // Shorthands for TfLite types
22 using ModelPtr = std::unique_ptr<tflite::ModelT>;
23 using SubGraphPtr = std::unique_ptr<tflite::SubGraphT>;
24 using OperatorPtr = std::unique_ptr<tflite::OperatorT>;
25 using OperatorCodePtr = std::unique_ptr<tflite::OperatorCodeT>;
26 using TensorPtr = std::unique_ptr<tflite::TensorT>;
27 using TensorRawPtr = const tflite::TensorT *;
28 using TensorRawPtrVector = std::vector<TensorRawPtr>;
29 using TensorIdRawPtr = std::pair<size_t, TensorRawPtr>;
30 using TensorIdRawPtrVector = std::vector<TensorIdRawPtr>;
31 using BufferPtr = std::unique_ptr<tflite::BufferT>;
32 using BufferRawPtr = const tflite::BufferT *;
33
34public:
35 /// Create the network from a flatbuffers binary file on disk
36 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
37
38 /// Create the network from a flatbuffers binary
39 virtual armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent) override;
40
41
42 /// Retrieve binding info (layer id and tensor info) for the network input identified by
43 /// the given layer name and subgraph id
44 virtual BindingPointInfo GetNetworkInputBindingInfo(size_t subgraphId,
45 const std::string& name) const override;
46
47 /// Retrieve binding info (layer id and tensor info) for the network output identified by
48 /// the given layer name and subgraph id
49 virtual BindingPointInfo GetNetworkOutputBindingInfo(size_t subgraphId,
50 const std::string& name) const override;
51
52 /// Return the number of subgraphs in the parsed model
53 virtual size_t GetSubgraphCount() const override;
54
55 /// Return the input tensor names for a given subgraph
56 virtual std::vector<std::string> GetSubgraphInputTensorNames(size_t subgraphId) const override;
57
58 /// Return the output tensor names for a given subgraph
59 virtual std::vector<std::string> GetSubgraphOutputTensorNames(size_t subgraphId) const override;
60
61 TfLiteParser();
62 virtual ~TfLiteParser() {}
63
64public:
65 // testable helpers
66 static ModelPtr LoadModelFromFile(const char * fileName);
67 static ModelPtr LoadModelFromBinary(const uint8_t * binaryContent, size_t len);
68 static TensorRawPtrVector GetInputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
69 static TensorRawPtrVector GetOutputs(const ModelPtr & model, size_t subgraphIndex, size_t operatorIndex);
70 static TensorIdRawPtrVector GetSubgraphInputs(const ModelPtr & model, size_t subgraphIndex);
71 static TensorIdRawPtrVector GetSubgraphOutputs(const ModelPtr & model, size_t subgraphIndex);
72 static std::vector<int32_t>& GetInputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
73 static std::vector<int32_t>& GetOutputTensorIds(const ModelPtr& model, size_t subgraphIndex, size_t operatorIndex);
74
75 static BufferRawPtr GetBuffer(const ModelPtr& model, size_t bufferIndex);
76 static armnn::TensorInfo OutputShapeOfSqueeze(const std::vector<uint32_t> & squeezeDims,
77 const armnn::TensorInfo & inputTensorInfo);
Sadikb94967b2018-09-19 15:30:00 +010078 static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
79 const std::vector<int32_t> & targetDimsIn);
telsoa01c577f2c2018-08-31 09:22:23 +010080
81private:
82 // No copying allowed until it is wanted and properly implemented
83 TfLiteParser(const TfLiteParser &) = delete;
84 TfLiteParser & operator=(const TfLiteParser &) = delete;
85
86 /// Create the network from an already loaded flatbuffers model
87 armnn::INetworkPtr CreateNetworkFromModel();
88
89 // signature for the parser functions
90 using OperatorParsingFunction = void(TfLiteParser::*)(size_t subgraphIndex, size_t operatorIndex);
91
92 void ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsc42c3842019-01-22 14:18:11 +000093 void ParseActivation(size_t subgraphIndex, size_t operatorIndex, armnn::ActivationFunction activationType);
telsoa01c577f2c2018-08-31 09:22:23 +010094 void ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesdb947e22019-02-08 18:52:21 -020095 void ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +010096 void ParseConcatenation(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +010097 void ParseConv2D(size_t subgraphIndex, size_t operatorIndex);
98 void ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex);
keidav011b3e2ea2019-02-21 10:07:37 +000099 void ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan8853c1f2018-10-22 09:04:18 +0100100 void ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex);
Finn Williamsc42c3842019-01-22 14:18:11 +0000101 void ParseLogistic(size_t subgraphIndex, size_t operatorIndex);
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100102 void ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesb8d805e2019-02-12 22:57:13 -0200103 void ParseMaximum(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves8f6d7a72019-02-12 22:58:18 -0200104 void ParseMinimum(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan58f39192018-09-17 14:14:39 +0100105 void ParseRelu(size_t subgraphIndex, size_t operatorIndex);
106 void ParseRelu6(size_t subgraphIndex, size_t operatorIndex);
Sadikb94967b2018-09-19 15:30:00 +0100107 void ParseReshape(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves3f58ddb2019-02-07 18:40:11 -0200108 void ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100109 void ParseSoftmax(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbaded142019-02-08 19:02:48 -0200110 void ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex);
Sadik Armagan479045b2018-10-01 11:51:37 +0100111 void ParseSqueeze(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves451d95b2019-02-12 22:59:22 -0200112 void ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesbbeae262019-02-07 18:37:39 -0200113 void ParseSub(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesd4ac6a42018-12-18 12:56:22 -0200114 void ParseAdd(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalvesf803f782018-12-18 13:40:30 -0200115 void ParseMul(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves2235cee2018-12-19 12:51:45 -0200116 void ParseMean(size_t subgraphIndex, size_t operatorIndex);
Bruno Goncalves6c2355b2018-12-19 12:52:01 -0200117 void ParsePad(size_t subgraphIndex, size_t operatorIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100118
Nattapat Chaimanowongb66504b2018-10-17 15:19:14 +0100119 void ParsePool(size_t subgraphIndex, size_t operatorIndex, armnn::PoolingAlgorithm algorithm);
120
telsoa01c577f2c2018-08-31 09:22:23 +0100121 void RegisterProducerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IOutputSlot* slot);
122 void RegisterConsumerOfTensor(size_t subgraphIndex, size_t tensorIndex, armnn::IInputSlot* slot);
123 void RegisterInputSlots(size_t subgraphIndex,
124 size_t operatorIndex,
125 armnn::IConnectableLayer* layer,
126 const std::vector<unsigned int>& tensorIndexes);
127 void RegisterOutputSlots(size_t subgraphIndex,
128 size_t operatorIndex,
129 armnn::IConnectableLayer* layer,
130 const std::vector<unsigned int>& tensorIndexes);
131
132 void SetupInputLayers(size_t subgraphIndex);
133 void SetupOutputLayers(size_t subgraphIndex);
Bruno Goncalves3d7efe92018-12-27 14:21:43 -0200134 void SetupConstantLayers(size_t subgraphIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100135
136 void ResetParser();
137
Bruno Goncalves9c761a62018-12-27 14:20:35 -0200138 void AddBroadcastReshapeLayer(size_t subgraphIndex,
139 size_t operatorIndex,
140 armnn::IConnectableLayer* layer);
141
telsoa01c577f2c2018-08-31 09:22:23 +0100142 /// Attach an activation layer to the one passed as a parameter
Sadik Armagan58f39192018-09-17 14:14:39 +0100143 armnn::IConnectableLayer* AddFusedActivationLayer(armnn::IConnectableLayer* layer,
144 unsigned int outputSlot,
145 tflite::ActivationFunctionType activationType);
telsoa01c577f2c2018-08-31 09:22:23 +0100146
147 // SupportedDataStorage's purpose is to hold data till we pass over to the network.
148 // We don't care about the content, and we want a single datatype to simplify the code.
149 struct SupportedDataStorage
150 {
Matteo Martincigh747ef822018-12-18 09:26:39 +0000151 public:
152 // Convenience constructors
153 SupportedDataStorage(std::unique_ptr<float[]>&& data);
154 SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data);
155 SupportedDataStorage(std::unique_ptr<int32_t[]>&& data);
telsoa01c577f2c2018-08-31 09:22:23 +0100156
Matteo Martincigh747ef822018-12-18 09:26:39 +0000157 private:
158 // Pointers to the data buffers
159 std::unique_ptr<float[]> m_FloatData;
160 std::unique_ptr<uint8_t[]> m_Uint8Data;
161 std::unique_ptr<int32_t[]> m_Int32Data;
telsoa01c577f2c2018-08-31 09:22:23 +0100162 };
163
Matteo Martincigh747ef822018-12-18 09:26:39 +0000164
165 template<typename T>
166 std::pair<armnn::ConstTensor, TfLiteParser::SupportedDataStorage>
167 CreateConstTensorAndStoreData(TfLiteParser::BufferRawPtr bufferPtr,
168 TfLiteParser::TensorRawPtr tensorPtr,
169 armnn::TensorInfo& tensorInfo,
170 armnn::Optional<armnn::PermutationVector&> permutationVector);
171
172 std::pair<armnn::ConstTensor, SupportedDataStorage>
173 CreateConstTensor(TensorRawPtr tensorPtr,
174 armnn::TensorInfo& tensorInfo,
175 armnn::Optional<armnn::PermutationVector&> permutationVector);
telsoa01c577f2c2018-08-31 09:22:23 +0100176
177 /// The network we're building. Gets cleared after it is passed to the user
178 armnn::INetworkPtr m_Network;
179 std::vector<OperatorParsingFunction> m_ParserFunctions;
180 ModelPtr m_Model;
181
182 /// A mapping of an output slot to each of the input slots it should be connected to
183 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
184 /// The inputSlots are from the layers that use this tensor as one of their inputs
185 struct TensorSlots
186 {
187 armnn::IOutputSlot* outputSlot;
188 std::vector<armnn::IInputSlot*> inputSlots;
189
190 TensorSlots() : outputSlot(nullptr) { }
191 };
192 typedef std::vector<TensorSlots> TensorConnections;
193 /// Connections for tensors in each subgraph
194 /// The first index is the subgraph ID, the second index is the tensor ID
195 std::vector<TensorConnections> m_SubgraphConnections;
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000196
197 /// This is used in case that the model does not speciry the output.
198 /// The shape can be calculated from the options.
199 std::vector<std::vector<unsigned int>> m_OverridenOutputShapes;
telsoa01c577f2c2018-08-31 09:22:23 +0100200};
201
202}