blob: 5c04cceb825dca1cde526eabed375f9dc83791c5 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "armnnTfParser/ITfParser.hpp"
8
9#include "armnn/Types.hpp"
10#include "armnn/Tensor.hpp"
11#include "armnn/INetwork.hpp"
12
narpra016f37f832018-12-21 18:30:00 +000013#include <list>
surmeh01bceff2f2018-03-29 16:29:27 +010014#include <map>
15#include <memory>
16#include <unordered_map>
jimfly0184c70e62018-12-19 13:14:46 +000017#include <utility>
surmeh01bceff2f2018-03-29 16:29:27 +010018#include <vector>
19
// Forward declarations of external types referenced by the declarations below.
namespace armnn
{
class TensorInfo;
} // namespace armnn

namespace tensorflow
{
class GraphDef; // whole TensorFlow graph
class NodeDef;  // one node of a GraphDef
} // namespace tensorflow
29}
30
31namespace armnnTfParser
32{
33
/// Result of parsing a single TensorFlow node (defined out of line).
class ParsedTfOperation;

/// Sole-owner handle to a ParsedTfOperation.
using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
36
///
/// WithOutputTensorIndex pairs a value with an output-tensor index. In TensorFlow, the input
/// name of a layer follows the convention 'inputTensorName:#index'. The ':#index' part may be
/// omitted, in which case it implicitly refers to output 0 of the referenced layer. Carrying
/// the index alongside the value lets the parser handle layers with multiple outputs, such as
/// Split.
///
template <typename T>
struct WithOutputTensorIndex
{
    T m_IndexedValue;
    unsigned int m_Index;

    /// Stores a copy of @p value together with @p index.
    WithOutputTensorIndex(const T & value, unsigned int index)
        : m_IndexedValue{value}
        , m_Index{index} {}

    /// Takes ownership of @p value (moved in) together with @p index.
    /// Inside this constructor the named rvalue reference `value` is an lvalue, so std::move is
    /// required: without it the value is copied (an extra allocation for std::string) and
    /// move-only types would fail to compile.
    WithOutputTensorIndex(T && value, unsigned int index)
        : m_IndexedValue{std::move(value)}
        , m_Index{index} {}
};
59
60using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
61using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
62using OutputId = WithOutputTensorIndex<std::string>;
63
Kevin May7d96b162021-02-03 17:38:41 +000064struct ITfParser::TfParserImpl
surmeh01bceff2f2018-03-29 16:29:27 +010065{
66public:
telsoa01c577f2c2018-08-31 09:22:23 +010067 /// Creates the network from a protobuf text file on the disk.
Kevin May7d96b162021-02-03 17:38:41 +000068 armnn::INetworkPtr CreateNetworkFromTextFile(
surmeh01bceff2f2018-03-29 16:29:27 +010069 const char* graphFile,
70 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000071 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010072
telsoa01c577f2c2018-08-31 09:22:23 +010073 /// Creates the network from a protobuf binary file on the disk.
Kevin May7d96b162021-02-03 17:38:41 +000074 armnn::INetworkPtr CreateNetworkFromBinaryFile(
surmeh01bceff2f2018-03-29 16:29:27 +010075 const char* graphFile,
76 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000077 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010078
telsoa01c577f2c2018-08-31 09:22:23 +010079 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
Kevin May7d96b162021-02-03 17:38:41 +000080 armnn::INetworkPtr CreateNetworkFromString(
surmeh01bceff2f2018-03-29 16:29:27 +010081 const char* protoText,
82 const std::map<std::string, armnn::TensorShape>& inputShapes,
Kevin May7d96b162021-02-03 17:38:41 +000083 const std::vector<std::string>& requestedOutputs);
surmeh01bceff2f2018-03-29 16:29:27 +010084
telsoa01c577f2c2018-08-31 09:22:23 +010085 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
Kevin May7d96b162021-02-03 17:38:41 +000086 BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;
surmeh01bceff2f2018-03-29 16:29:27 +010087
telsoa01c577f2c2018-08-31 09:22:23 +010088 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
Kevin May7d96b162021-02-03 17:38:41 +000089 BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;
surmeh01bceff2f2018-03-29 16:29:27 +010090
Kevin May7d96b162021-02-03 17:38:41 +000091 TfParserImpl();
92 ~TfParserImpl() = default;
surmeh01bceff2f2018-03-29 16:29:27 +010093
Kevin May7d96b162021-02-03 17:38:41 +000094 TfParserImpl(const TfParserImpl&) = delete;
95 TfParserImpl& operator=(const TfParserImpl&) = delete;
surmeh01bceff2f2018-03-29 16:29:27 +010096
telsoa01c577f2c2018-08-31 09:22:23 +010097 /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*.
surmeh01bceff2f2018-03-29 16:29:27 +010098 armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
99 const std::map<std::string, armnn::TensorShape>& inputShapes,
100 const std::vector<std::string>& requestedOutputs);
101
telsoa01c577f2c2018-08-31 09:22:23 +0100102 /// Sets up variables and then performs BFS to parse all nodes.
surmeh01bceff2f2018-03-29 16:29:27 +0100103 void LoadGraphDef(const tensorflow::GraphDef& graphDef);
104
telsoa01c577f2c2018-08-31 09:22:23 +0100105 /// Parses a given node, assuming nodes before it in the graph have been done.
surmeh01bceff2f2018-03-29 16:29:27 +0100106 void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 /// Handling identity layers as the input for Conv2D layer.
surmeh01bceff2f2018-03-29 16:29:27 +0100109 const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
110 /// Finds the nodes connected as inputs of the given node in the graph.
111 std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
112 /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
113 /// and throws an exception if the number of inputs does not match the expected one.
114 /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
115 /// together with the output tensor index to make the connection unambiguous.
116 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
117 std::size_t expectedNumInputs);
118
119 ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
120
telsoa01c577f2c2018-08-31 09:22:23 +0100121 /// Checks if there is a pre-parsed const tensor available with the given name and Type.
surmeh01bceff2f2018-03-29 16:29:27 +0100122 template<typename Type>
123 bool HasParsedConstTensor(const std::string & nodeName) const;
jimfly01f6ba7472018-12-04 10:09:52 +0000124 template<typename Type>
125 bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const;
surmeh01bceff2f2018-03-29 16:29:27 +0100126
Saoirse Stewart91c0eff2019-02-27 11:07:57 +0000127 unsigned int GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs);
128
surmeh01bceff2f2018-03-29 16:29:27 +0100129 ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Ferran Balaguerfbdad032018-12-28 18:15:24 +0000130 ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100131 ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
132 ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Jim Flynn6cde7ed2019-02-20 14:25:11 +0000133 ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,
134 const tensorflow::GraphDef& graphDef);
135 ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100136 ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
137 ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
138 ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139 ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140 ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000141 ParsedTfOperationPtr ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100142 ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
143 ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
saoste01bbd40612018-08-28 15:41:51 +0100144 ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100145 ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
146 ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
147 ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148 ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Mohamed Nour Abouelseoud7a8892f2019-01-09 14:19:58 +0000149 ParsedTfOperationPtr ParseRsqrt(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100150 ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
151 ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
152 ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
153 ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
154 ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000155 ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Georgios Pinitas5e90aab2020-02-14 14:46:51 +0000156 ParsedTfOperationPtr ParseStridedSlice(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100157 ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
158 ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
159 ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Jim Flynn6cde7ed2019-02-20 14:25:11 +0000160 ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef,
161 const tensorflow::GraphDef& graphDef,
162 armnn::PoolingAlgorithm pooltype);
jimfly0184c70e62018-12-19 13:14:46 +0000163 ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
telsoa01c577f2c2018-08-31 09:22:23 +0100164 ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Nattapat Chaimanowong24df8222018-12-04 13:47:02 +0000165 ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
FrancisMurtagh94412af2019-01-24 10:53:39 +0000166 ParsedTfOperationPtr ParseGather(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly01a06bf312018-12-18 16:24:51 +0000167 ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly01f6ba7472018-12-04 10:09:52 +0000168 ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly0123be07e2018-12-04 17:47:22 +0000169 ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Sadik Armagan48d70932020-02-18 15:18:27 +0000170 ParsedTfOperationPtr ParseStack(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Sang-Hoon Parkdd3f71b2020-02-18 11:27:35 +0000171 ParsedTfOperationPtr ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100172 ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
173 ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
saoste01bbd40612018-08-28 15:41:51 +0100174 ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef);
Sadik Armagan975c09a2018-12-04 10:02:08 +0000175 ParsedTfOperationPtr AddMaximumLayer(const tensorflow::NodeDef& nodeDef);
telsoa01c577f2c2018-08-31 09:22:23 +0100176
telsoa01c577f2c2018-08-31 09:22:23 +0100177 armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef);
178
surmeh01bceff2f2018-03-29 16:29:27 +0100179 armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
180 const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
181
telsoa01c577f2c2018-08-31 09:22:23 +0100182 bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef,
183 size_t alphaLayerIndex,
184 const OutputOfParsedTfOperation& otherOp,
185 armnn::IOutputSlot** outputOfLeakyRelu,
186 armnn::ActivationDescriptor & desc);
187
jimfly0184c70e62018-12-19 13:14:46 +0000188 std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots(
189 const tensorflow::NodeDef& nodeDef, const std::string& layerName);
190
kevmay012b4d88e2019-01-24 14:05:09 +0000191 ParsedTfOperationPtr ProcessComparisonLayer(
192 armnn::IOutputSlot* input0Slot,
193 armnn::IOutputSlot* input1Slot,
194 armnn::IConnectableLayer* const layer,
195 const tensorflow::NodeDef& nodeDef);
196
jimfly0184c70e62018-12-19 13:14:46 +0000197 ParsedTfOperationPtr ProcessElementwiseLayer(
198 armnn::IOutputSlot* input0Slot,
199 armnn::IOutputSlot* input1Slot,
200 armnn::IConnectableLayer* const layer,
201 const tensorflow::NodeDef& nodeDef);
202
Ferran Balaguerfbdad032018-12-28 18:15:24 +0000203 armnn::IConnectableLayer* CreateAdditionLayer(
204 const tensorflow::NodeDef& nodeDef,
205 armnn::IOutputSlot* input0Slot,
206 armnn::IOutputSlot* input1Slot,
207 const std::string& layerName);
208
209 armnn::IConnectableLayer* CreateAdditionLayer(
210 const tensorflow::NodeDef& nodeDef,
211 const OutputOfParsedTfOperation& opOne,
212 const OutputOfParsedTfOperation& opTwo,
213 unsigned int numberOfAddition);
214
215 armnn::IConnectableLayer* CreateAdditionLayer(
216 const tensorflow::NodeDef& nodeDef,
217 armnn::IConnectableLayer* layerOne,
218 armnn::IConnectableLayer* layerTwo,
219 unsigned int numberOfAddition,
220 unsigned long numberOfLayersToConnect,
221 bool isOdd);
222
223 armnn::IConnectableLayer* CreateAdditionLayer(
224 const tensorflow::NodeDef& nodeDef,
225 const OutputOfParsedTfOperation& op,
226 armnn::IConnectableLayer* layer);
227
surmeh01bceff2f2018-03-29 16:29:27 +0100228 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
229 const char* bindingPointDesc,
230 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
231
232 void TrackInputBinding(armnn::IConnectableLayer* layer,
233 armnn::LayerBindingId id,
234 const armnn::TensorInfo& tensorInfo);
235
236 void TrackOutputBinding(armnn::IConnectableLayer* layer,
237 armnn::LayerBindingId id,
238 const armnn::TensorInfo& tensorInfo);
239
240 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
241 const armnn::TensorInfo& tensorInfo,
242 const char* bindingPointDesc,
243 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
244
245 void Cleanup();
246
telsoa01c577f2c2018-08-31 09:22:23 +0100247 /// The network we're building. Gets cleared after it is passed to the user.
surmeh01bceff2f2018-03-29 16:29:27 +0100248 armnn::INetworkPtr m_Network;
249
Kevin May7d96b162021-02-03 17:38:41 +0000250 using OperationParsingFunction = ParsedTfOperationPtr(TfParserImpl::*)(const tensorflow::NodeDef& nodeDef,
251 const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100252
telsoa01c577f2c2018-08-31 09:22:23 +0100253 /// Map of TensorFlow operation names to parsing member functions.
surmeh01bceff2f2018-03-29 16:29:27 +0100254 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
255
narpra016f37f832018-12-21 18:30:00 +0000256 static const std::list<std::string> m_ControlInputs;
257
surmeh01bceff2f2018-03-29 16:29:27 +0100258 std::map<std::string, armnn::TensorShape> m_InputShapes;
259 std::vector<std::string> m_RequestedOutputs;
260
telsoa01c577f2c2018-08-31 09:22:23 +0100261 /// Map of nodes extracted from the GraphDef to speed up parsing.
surmeh01bceff2f2018-03-29 16:29:27 +0100262 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
263
264 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
265
telsoa01c577f2c2018-08-31 09:22:23 +0100266 /// Maps input layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100267 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
268
telsoa01c577f2c2018-08-31 09:22:23 +0100269 /// Maps output layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100270 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
271};
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000272
surmeh01bceff2f2018-03-29 16:29:27 +0100273}