blob: f1b7205ff1e22aadfd8bc4ab9c9f1b45c5bcb2b4 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "armnnTfParser/ITfParser.hpp"
8
9#include "armnn/Types.hpp"
10#include "armnn/Tensor.hpp"
11#include "armnn/INetwork.hpp"
12
narpra016f37f832018-12-21 18:30:00 +000013#include <list>
surmeh01bceff2f2018-03-29 16:29:27 +010014#include <map>
15#include <memory>
16#include <unordered_map>
jimfly0184c70e62018-12-19 13:14:46 +000017#include <utility>
surmeh01bceff2f2018-03-29 16:29:27 +010018#include <vector>
19
// Forward declarations of external types referenced by this header.
// The TensorFlow types are only used through pointers and references below,
// so the protobuf-generated headers do not need to be included here.
// (armnn::TensorInfo is also provided by the included armnn/Tensor.hpp.)
namespace armnn { class TensorInfo; }

namespace tensorflow { class GraphDef; class NodeDef; }
30
31namespace armnnTfParser
32{
33
34using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
35
36class ParsedTfOperation;
37using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
38
///
/// WithOutputTensorIndex wraps a value and an index. The purpose of
/// this template is to signify that, in TensorFlow, the input name of
/// a layer has the convention of 'inputTensorName:#index', where the
/// #index can be omitted and it implicitly means the 0 output of
/// the referenced layer. By supporting this notation we can handle
/// layers with multiple outputs, such as Split.
///
template <typename T>
struct WithOutputTensorIndex
{
    /// The wrapped value (a parsed operation, a NodeDef pointer or a tensor name in the aliases below).
    T m_IndexedValue;
    /// Index of the referenced output tensor of the layer.
    unsigned int m_Index;

    /// Wraps a copy of @p value together with output index @p index.
    WithOutputTensorIndex(const T& value, unsigned int index)
        : m_IndexedValue{value}
        , m_Index{index} {}

    /// Wraps @p value by moving it in, together with output index @p index.
    WithOutputTensorIndex(T&& value, unsigned int index)
        // A named rvalue reference is itself an lvalue: without std::move this
        // overload would copy (and would not compile for move-only T).
        : m_IndexedValue{std::move(value)}
        , m_Index{index} {}
};
61
62using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
63using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
64using OutputId = WithOutputTensorIndex<std::string>;
65
66class TfParser : public ITfParser
67{
68public:
telsoa01c577f2c2018-08-31 09:22:23 +010069 /// Creates the network from a protobuf text file on the disk.
surmeh01bceff2f2018-03-29 16:29:27 +010070 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
71 const char* graphFile,
72 const std::map<std::string, armnn::TensorShape>& inputShapes,
73 const std::vector<std::string>& requestedOutputs) override;
74
telsoa01c577f2c2018-08-31 09:22:23 +010075 /// Creates the network from a protobuf binary file on the disk.
surmeh01bceff2f2018-03-29 16:29:27 +010076 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
77 const char* graphFile,
78 const std::map<std::string, armnn::TensorShape>& inputShapes,
79 const std::vector<std::string>& requestedOutputs) override;
80
telsoa01c577f2c2018-08-31 09:22:23 +010081 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
surmeh01bceff2f2018-03-29 16:29:27 +010082 virtual armnn::INetworkPtr CreateNetworkFromString(
83 const char* protoText,
84 const std::map<std::string, armnn::TensorShape>& inputShapes,
85 const std::vector<std::string>& requestedOutputs) override;
86
telsoa01c577f2c2018-08-31 09:22:23 +010087 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
surmeh01bceff2f2018-03-29 16:29:27 +010088 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
89
telsoa01c577f2c2018-08-31 09:22:23 +010090 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
surmeh01bceff2f2018-03-29 16:29:27 +010091 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
92
93public:
94 TfParser();
95
96private:
97 template <typename T>
98 friend class ParsedConstTfOperation;
99 friend class ParsedMatMulTfOperation;
telsoa01c577f2c2018-08-31 09:22:23 +0100100 friend class ParsedMulTfOperation;
surmeh01bceff2f2018-03-29 16:29:27 +0100101
telsoa01c577f2c2018-08-31 09:22:23 +0100102 /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*.
surmeh01bceff2f2018-03-29 16:29:27 +0100103 armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
104 const std::map<std::string, armnn::TensorShape>& inputShapes,
105 const std::vector<std::string>& requestedOutputs);
106
telsoa01c577f2c2018-08-31 09:22:23 +0100107 /// Sets up variables and then performs BFS to parse all nodes.
surmeh01bceff2f2018-03-29 16:29:27 +0100108 void LoadGraphDef(const tensorflow::GraphDef& graphDef);
109
telsoa01c577f2c2018-08-31 09:22:23 +0100110 /// Parses a given node, assuming nodes before it in the graph have been done.
surmeh01bceff2f2018-03-29 16:29:27 +0100111 void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
112
telsoa01c577f2c2018-08-31 09:22:23 +0100113 /// Handling identity layers as the input for Conv2D layer.
surmeh01bceff2f2018-03-29 16:29:27 +0100114 const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
115 /// Finds the nodes connected as inputs of the given node in the graph.
116 std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
117 /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
118 /// and throws an exception if the number of inputs does not match the expected one.
119 /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
120 /// together with the output tensor index to make the connection unambiguous.
121 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
122 std::size_t expectedNumInputs);
123
124 ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
125
telsoa01c577f2c2018-08-31 09:22:23 +0100126 /// Checks if there is a pre-parsed const tensor available with the given name and Type.
surmeh01bceff2f2018-03-29 16:29:27 +0100127 template<typename Type>
128 bool HasParsedConstTensor(const std::string & nodeName) const;
jimfly01f6ba7472018-12-04 10:09:52 +0000129 template<typename Type>
130 bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const;
surmeh01bceff2f2018-03-29 16:29:27 +0100131
132 ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Ferran Balaguerfbdad032018-12-28 18:15:24 +0000133 ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100134 ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
135 ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
136 ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
Conor Kennedyc2130a02018-12-05 11:05:54 +0000137 ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100138 ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139 ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140 ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
141 ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
142 ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000143 ParsedTfOperationPtr ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100144 ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
145 ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
saoste01bbd40612018-08-28 15:41:51 +0100146 ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100147 ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148 ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
149 ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
150 ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Mohamed Nour Abouelseoud7a8892f2019-01-09 14:19:58 +0000151 ParsedTfOperationPtr ParseRsqrt(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100152 ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
153 ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
154 ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
155 ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
156 ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Sadik Armagan2ad6cb42018-12-27 11:23:44 +0000157 ParsedTfOperationPtr ParseSplit(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100158 ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
159 ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
160 ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
161 ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
162 armnn::PoolingAlgorithm pooltype);
jimfly0184c70e62018-12-19 13:14:46 +0000163 ParsedTfOperationPtr ParseEqual(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
telsoa01c577f2c2018-08-31 09:22:23 +0100164 ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Nattapat Chaimanowong24df8222018-12-04 13:47:02 +0000165 ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly01a06bf312018-12-18 16:24:51 +0000166 ParsedTfOperationPtr ParseGreater(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly01f6ba7472018-12-04 10:09:52 +0000167 ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly0123be07e2018-12-04 17:47:22 +0000168 ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100169 ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
170 ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
saoste01bbd40612018-08-28 15:41:51 +0100171 ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef);
Sadik Armagan975c09a2018-12-04 10:02:08 +0000172 ParsedTfOperationPtr AddMaximumLayer(const tensorflow::NodeDef& nodeDef);
telsoa01c577f2c2018-08-31 09:22:23 +0100173
174private:
175 armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef);
176
surmeh01bceff2f2018-03-29 16:29:27 +0100177 armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
178 const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
179
telsoa01c577f2c2018-08-31 09:22:23 +0100180 bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef,
181 size_t alphaLayerIndex,
182 const OutputOfParsedTfOperation& otherOp,
183 armnn::IOutputSlot** outputOfLeakyRelu,
184 armnn::ActivationDescriptor & desc);
185
jimfly0184c70e62018-12-19 13:14:46 +0000186 std::pair<armnn::IOutputSlot*, armnn::IOutputSlot*> ProcessElementwiseInputSlots(
187 const tensorflow::NodeDef& nodeDef, const std::string& layerName);
188
189 ParsedTfOperationPtr ProcessElementwiseLayer(
190 armnn::IOutputSlot* input0Slot,
191 armnn::IOutputSlot* input1Slot,
192 armnn::IConnectableLayer* const layer,
193 const tensorflow::NodeDef& nodeDef);
194
Ferran Balaguerfbdad032018-12-28 18:15:24 +0000195 armnn::IConnectableLayer* CreateAdditionLayer(
196 const tensorflow::NodeDef& nodeDef,
197 armnn::IOutputSlot* input0Slot,
198 armnn::IOutputSlot* input1Slot,
199 const std::string& layerName);
200
201 armnn::IConnectableLayer* CreateAdditionLayer(
202 const tensorflow::NodeDef& nodeDef,
203 const OutputOfParsedTfOperation& opOne,
204 const OutputOfParsedTfOperation& opTwo,
205 unsigned int numberOfAddition);
206
207 armnn::IConnectableLayer* CreateAdditionLayer(
208 const tensorflow::NodeDef& nodeDef,
209 armnn::IConnectableLayer* layerOne,
210 armnn::IConnectableLayer* layerTwo,
211 unsigned int numberOfAddition,
212 unsigned long numberOfLayersToConnect,
213 bool isOdd);
214
215 armnn::IConnectableLayer* CreateAdditionLayer(
216 const tensorflow::NodeDef& nodeDef,
217 const OutputOfParsedTfOperation& op,
218 armnn::IConnectableLayer* layer);
219
surmeh01bceff2f2018-03-29 16:29:27 +0100220 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
221 const char* bindingPointDesc,
222 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
223
224 void TrackInputBinding(armnn::IConnectableLayer* layer,
225 armnn::LayerBindingId id,
226 const armnn::TensorInfo& tensorInfo);
227
228 void TrackOutputBinding(armnn::IConnectableLayer* layer,
229 armnn::LayerBindingId id,
230 const armnn::TensorInfo& tensorInfo);
231
232 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
233 const armnn::TensorInfo& tensorInfo,
234 const char* bindingPointDesc,
235 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
236
237 void Cleanup();
238
telsoa01c577f2c2018-08-31 09:22:23 +0100239 /// The network we're building. Gets cleared after it is passed to the user.
surmeh01bceff2f2018-03-29 16:29:27 +0100240 armnn::INetworkPtr m_Network;
241
242 using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
243 const tensorflow::GraphDef& graphDef);
244
telsoa01c577f2c2018-08-31 09:22:23 +0100245 /// Map of TensorFlow operation names to parsing member functions.
surmeh01bceff2f2018-03-29 16:29:27 +0100246 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
247
narpra016f37f832018-12-21 18:30:00 +0000248 static const std::list<std::string> m_ControlInputs;
249
surmeh01bceff2f2018-03-29 16:29:27 +0100250 std::map<std::string, armnn::TensorShape> m_InputShapes;
251 std::vector<std::string> m_RequestedOutputs;
252
telsoa01c577f2c2018-08-31 09:22:23 +0100253 /// Map of nodes extracted from the GraphDef to speed up parsing.
surmeh01bceff2f2018-03-29 16:29:27 +0100254 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
255
256 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
257
telsoa01c577f2c2018-08-31 09:22:23 +0100258 /// Maps input layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100259 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
260
telsoa01c577f2c2018-08-31 09:22:23 +0100261 /// Maps output layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100262 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
263};
Ferran Balaguer51dd62f2019-01-11 19:29:18 +0000264
surmeh01bceff2f2018-03-29 16:29:27 +0100265}