blob: 5ca867c0f766566753195c3cd8abd32e7c6260f7 [file] [log] [blame]
surmeh01bceff2f2018-03-29 16:29:27 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
surmeh01bceff2f2018-03-29 16:29:27 +01004//
5#pragma once
6
#include "armnnTfParser/ITfParser.hpp"

#include "armnn/Types.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"

#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
17
// Forward declaration only; the full definition is provided by the armnn headers.
namespace armnn { class TensorInfo; }
22
// TensorFlow protobuf message types; only pointers/references are used in this header,
// so forward declarations are sufficient.
namespace tensorflow
{
class NodeDef;
class GraphDef;
} // namespace tensorflow
28
29namespace armnnTfParser
30{
31
32using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
33
34class ParsedTfOperation;
35using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
36
///
/// WithOutputTensorIndex wraps a value and an index. The purpose of
/// this template is to signify that, in TensorFlow, the input name of
/// a layer has the convention of 'inputTensorName:#index', where the
/// #index can be omitted, in which case it implicitly refers to output 0
/// of the referenced layer. By supporting this notation we can handle
/// layers with multiple outputs, such as Split.
///
/// @tparam T Type of the wrapped value (e.g. an operation/node pointer or a tensor name).
///
template <typename T>
struct WithOutputTensorIndex
{
    /// The wrapped value identifying the referenced layer/tensor.
    T m_IndexedValue;
    /// Index of the referenced output of that layer.
    unsigned int m_Index;

    /// Wraps a copy of @p value together with output index @p index.
    WithOutputTensorIndex(const T & value, unsigned int index)
    : m_IndexedValue{value}
    , m_Index{index} {}

    /// Wraps @p value by moving it, together with output index @p index.
    /// A named rvalue reference is itself an lvalue, so std::move is required:
    /// without it the value would be copied, and move-only types could not be wrapped.
    WithOutputTensorIndex(T && value, unsigned int index)
    : m_IndexedValue{std::move(value)}
    , m_Index{index} {}
};
59
60using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
61using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
62using OutputId = WithOutputTensorIndex<std::string>;
63
64class TfParser : public ITfParser
65{
66public:
telsoa01c577f2c2018-08-31 09:22:23 +010067 /// Creates the network from a protobuf text file on the disk.
surmeh01bceff2f2018-03-29 16:29:27 +010068 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
69 const char* graphFile,
70 const std::map<std::string, armnn::TensorShape>& inputShapes,
71 const std::vector<std::string>& requestedOutputs) override;
72
telsoa01c577f2c2018-08-31 09:22:23 +010073 /// Creates the network from a protobuf binary file on the disk.
surmeh01bceff2f2018-03-29 16:29:27 +010074 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
75 const char* graphFile,
76 const std::map<std::string, armnn::TensorShape>& inputShapes,
77 const std::vector<std::string>& requestedOutputs) override;
78
telsoa01c577f2c2018-08-31 09:22:23 +010079 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
surmeh01bceff2f2018-03-29 16:29:27 +010080 virtual armnn::INetworkPtr CreateNetworkFromString(
81 const char* protoText,
82 const std::map<std::string, armnn::TensorShape>& inputShapes,
83 const std::vector<std::string>& requestedOutputs) override;
84
telsoa01c577f2c2018-08-31 09:22:23 +010085 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
surmeh01bceff2f2018-03-29 16:29:27 +010086 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
87
telsoa01c577f2c2018-08-31 09:22:23 +010088 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
surmeh01bceff2f2018-03-29 16:29:27 +010089 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
90
91public:
92 TfParser();
93
94private:
95 template <typename T>
96 friend class ParsedConstTfOperation;
97 friend class ParsedMatMulTfOperation;
telsoa01c577f2c2018-08-31 09:22:23 +010098 friend class ParsedMulTfOperation;
surmeh01bceff2f2018-03-29 16:29:27 +010099
telsoa01c577f2c2018-08-31 09:22:23 +0100100 /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*.
surmeh01bceff2f2018-03-29 16:29:27 +0100101 armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
102 const std::map<std::string, armnn::TensorShape>& inputShapes,
103 const std::vector<std::string>& requestedOutputs);
104
telsoa01c577f2c2018-08-31 09:22:23 +0100105 /// Sets up variables and then performs BFS to parse all nodes.
surmeh01bceff2f2018-03-29 16:29:27 +0100106 void LoadGraphDef(const tensorflow::GraphDef& graphDef);
107
telsoa01c577f2c2018-08-31 09:22:23 +0100108 /// Parses a given node, assuming nodes before it in the graph have been done.
surmeh01bceff2f2018-03-29 16:29:27 +0100109 void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
110
telsoa01c577f2c2018-08-31 09:22:23 +0100111 /// Handling identity layers as the input for Conv2D layer.
surmeh01bceff2f2018-03-29 16:29:27 +0100112 const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
113 /// Finds the nodes connected as inputs of the given node in the graph.
114 std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
115 /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
116 /// and throws an exception if the number of inputs does not match the expected one.
117 /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
118 /// together with the output tensor index to make the connection unambiguous.
119 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
120 std::size_t expectedNumInputs);
121
122 ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
123
telsoa01c577f2c2018-08-31 09:22:23 +0100124 /// Checks if there is a pre-parsed const tensor available with the given name and Type.
surmeh01bceff2f2018-03-29 16:29:27 +0100125 template<typename Type>
126 bool HasParsedConstTensor(const std::string & nodeName) const;
jimfly01f6ba7472018-12-04 10:09:52 +0000127 template<typename Type>
128 bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const;
surmeh01bceff2f2018-03-29 16:29:27 +0100129
130 ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
131 ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
132 ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
133 ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
Conor Kennedyc2130a02018-12-05 11:05:54 +0000134 ParsedTfOperationPtr ParseExpandDims(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100135 ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
136 ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
137 ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
138 ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139 ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140 ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
141 ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
saoste01bbd40612018-08-28 15:41:51 +0100142 ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100143 ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
144 ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
145 ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
146 ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
147 ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148 ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
149 ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
150 ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
151 ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
152 ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
153 ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
154 ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
155 ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
156 armnn::PoolingAlgorithm pooltype);
telsoa01c577f2c2018-08-31 09:22:23 +0100157 ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
Nattapat Chaimanowong24df8222018-12-04 13:47:02 +0000158 ParsedTfOperationPtr ParseMinimum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly01f6ba7472018-12-04 10:09:52 +0000159 ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
jimfly0123be07e2018-12-04 17:47:22 +0000160 ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
surmeh01bceff2f2018-03-29 16:29:27 +0100161 ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
162 ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
saoste01bbd40612018-08-28 15:41:51 +0100163 ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef);
Sadik Armagan975c09a2018-12-04 10:02:08 +0000164 ParsedTfOperationPtr AddMaximumLayer(const tensorflow::NodeDef& nodeDef);
telsoa01c577f2c2018-08-31 09:22:23 +0100165
166private:
167 armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef);
168
surmeh01bceff2f2018-03-29 16:29:27 +0100169 armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
170 const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
171
telsoa01c577f2c2018-08-31 09:22:23 +0100172 bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef,
173 size_t alphaLayerIndex,
174 const OutputOfParsedTfOperation& otherOp,
175 armnn::IOutputSlot** outputOfLeakyRelu,
176 armnn::ActivationDescriptor & desc);
177
surmeh01bceff2f2018-03-29 16:29:27 +0100178 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
179 const char* bindingPointDesc,
180 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
181
182 void TrackInputBinding(armnn::IConnectableLayer* layer,
183 armnn::LayerBindingId id,
184 const armnn::TensorInfo& tensorInfo);
185
186 void TrackOutputBinding(armnn::IConnectableLayer* layer,
187 armnn::LayerBindingId id,
188 const armnn::TensorInfo& tensorInfo);
189
190 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
191 const armnn::TensorInfo& tensorInfo,
192 const char* bindingPointDesc,
193 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
194
195 void Cleanup();
196
telsoa01c577f2c2018-08-31 09:22:23 +0100197 /// The network we're building. Gets cleared after it is passed to the user.
surmeh01bceff2f2018-03-29 16:29:27 +0100198 armnn::INetworkPtr m_Network;
199
200 using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
201 const tensorflow::GraphDef& graphDef);
202
telsoa01c577f2c2018-08-31 09:22:23 +0100203 /// Map of TensorFlow operation names to parsing member functions.
surmeh01bceff2f2018-03-29 16:29:27 +0100204 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
205
206 std::map<std::string, armnn::TensorShape> m_InputShapes;
207 std::vector<std::string> m_RequestedOutputs;
208
telsoa01c577f2c2018-08-31 09:22:23 +0100209 /// Map of nodes extracted from the GraphDef to speed up parsing.
surmeh01bceff2f2018-03-29 16:29:27 +0100210 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
211
212 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
213
telsoa01c577f2c2018-08-31 09:22:23 +0100214 /// Maps input layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100215 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
216
telsoa01c577f2c2018-08-31 09:22:23 +0100217 /// Maps output layer names to their corresponding ids and tensor info.
surmeh01bceff2f2018-03-29 16:29:27 +0100218 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
219};
220}