blob: bb94472c6dfcc0e82f5179c18e78f3501f6ebaf9 [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "google/protobuf/repeated_field.h"

#include <onnx/onnx.pb.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
12
13
// Forward declarations of the armnn types named in the parser declarations below.
namespace armnn
{
enum class ActivationFunction;
class TensorInfo;
} // namespace armnn
19
20namespace armnnOnnxParser
21{
22
telsoa01c577f2c2018-08-31 09:22:23 +010023using ModelPtr = std::unique_ptr<onnx::ModelProto>;
24
Kevin Mayef33cb12021-01-29 14:24:57 +000025class OnnxParserImpl
telsoa01c577f2c2018-08-31 09:22:23 +010026{
27
Kevin Mayef33cb12021-01-29 14:24:57 +000028using OperationParsingFunction = void(OnnxParserImpl::*)(const onnx::NodeProto& NodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +010029
30public:
31
32 using GraphPtr = std::unique_ptr<onnx::GraphProto>;
33
34 /// Create the network from a protobuf binary file on disk
Kevin Mayef33cb12021-01-29 14:24:57 +000035 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010036
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +010037 /// Create the network from a protobuf binary file on disk, with inputShapes specified
38 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile,
39 const std::map<std::string, armnn::TensorShape>& inputShapes);
40
telsoa01c577f2c2018-08-31 09:22:23 +010041 /// Create the network from a protobuf text file on disk
Kevin Mayef33cb12021-01-29 14:24:57 +000042 armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010043
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +010044 /// Create the network from a protobuf text file on disk, with inputShapes specified
45 armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile,
46 const std::map<std::string, armnn::TensorShape>& inputShapes);
47
telsoa01c577f2c2018-08-31 09:22:23 +010048 /// Create the network directly from protobuf text in a string. Useful for debugging/testing
Kevin Mayef33cb12021-01-29 14:24:57 +000049 armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText);
telsoa01c577f2c2018-08-31 09:22:23 +010050
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +010051 /// Create the network directly from protobuf text in a string, with inputShapes specified.
52 /// Useful for debugging/testing
53 armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText,
54 const std::map<std::string, armnn::TensorShape>& inputShapes);
55
telsoa01c577f2c2018-08-31 09:22:23 +010056 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
Kevin Mayef33cb12021-01-29 14:24:57 +000057 BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010058
59 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
Kevin Mayef33cb12021-01-29 14:24:57 +000060 BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010061
62public:
63
Kevin Mayef33cb12021-01-29 14:24:57 +000064 OnnxParserImpl();
65 ~OnnxParserImpl() = default;
telsoa01c577f2c2018-08-31 09:22:23 +010066
67 static ModelPtr LoadModelFromBinaryFile(const char * fileName);
68 static ModelPtr LoadModelFromTextFile(const char * fileName);
69 static ModelPtr LoadModelFromString(const std::string& inputString);
70
Ryan OShea337c17f2020-02-21 12:33:17 +000071 /// Retrieve inputs names
telsoa01c577f2c2018-08-31 09:22:23 +010072 static std::vector<std::string> GetInputs(ModelPtr& model);
73
Ryan OShea337c17f2020-02-21 12:33:17 +000074 /// Retrieve outputs names
telsoa01c577f2c2018-08-31 09:22:23 +010075 static std::vector<std::string> GetOutputs(ModelPtr& model);
76
Matthew Sloyanac001ee2021-02-03 10:43:04 +000077 /// Retrieve version in X.Y.Z form
78 static const std::string GetVersion();
79
telsoa01c577f2c2018-08-31 09:22:23 +010080private:
81
82 /// Parses a ModelProto loaded into memory from one of the other CreateNetwork*
83 armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);
84
Ryan OShea337c17f2020-02-21 12:33:17 +000085 /// Parse every node and make the connection between the resulting tensors
telsoa01c577f2c2018-08-31 09:22:23 +010086 void LoadGraph();
87
88 void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list);
89
Narumol Prangnawarat452274c2021-09-23 16:12:19 +010090 std::vector<armnn::TensorInfo> ComputeOutputInfo(
91 std::vector<std::string> outNames,
92 const armnn::IConnectableLayer* layer,
93 std::vector<armnn::TensorShape> inputShapes,
94 const onnx::TensorProto::DataType& type = onnx::TensorProto::FLOAT);
telsoa01c577f2c2018-08-31 09:22:23 +010095
96 void DetectFullyConnected();
97
98 template <typename Location>
99 void GetInputAndParam(const onnx::NodeProto& node,
100 std::string* inputName,
101 std::string* constName,
102 const Location& location);
103
104 template <typename Location>
105 void To1DTensor(const std::string &name, const Location& location);
106
107 //Broadcast Preparation functions
108 std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
109 void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);
110
Ryan OSheaed27ee72020-04-22 16:37:29 +0100111 void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
112 void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
113 void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
114
telsoa01c577f2c2018-08-31 09:22:23 +0100115 void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100116 void CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName);
telsoa01c577f2c2018-08-31 09:22:23 +0100117 void CreateReshapeLayer(const std::string& inputName,
118 const std::string& outputName,
119 const std::string& layerName);
120
Tee Jung7ff9a602019-11-01 07:04:42 +0000121 void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
Finn Williams7ee5d2c2020-03-27 11:11:50 +0000122 void ParseClip(const onnx::NodeProto& nodeProto);
Tee Jung7ff9a602019-11-01 07:04:42 +0000123 void ParseSigmoid(const onnx::NodeProto& nodeProto);
124 void ParseTanh(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100125 void ParseRelu(const onnx::NodeProto& nodeProto);
Tee Jung7ff9a602019-11-01 07:04:42 +0000126 void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100127
telsoa01c577f2c2018-08-31 09:22:23 +0100128 void ParseAdd(const onnx::NodeProto& nodeProto);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100129 void ParseAveragePool(const onnx::NodeProto& nodeProto);
130 void ParseBatchNormalization(const onnx::NodeProto& node);
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +0100131 void ParseConcat(const onnx::NodeProto& nodeProto);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100132 void ParseConstant(const onnx::NodeProto& nodeProto);
133 void ParseConv(const onnx::NodeProto& nodeProto);
134 void ParseFlatten(const onnx::NodeProto& node);
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100135 void ParseGather(const onnx::NodeProto& node);
Narumol Prangnawarat1112b012021-09-30 12:10:50 +0100136 void ParseGemm(const onnx::NodeProto& node);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100137 void ParseGlobalAveragePool(const onnx::NodeProto& node);
138 void ParseMaxPool(const onnx::NodeProto& nodeProto);
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100139 void ParseShape(const onnx::NodeProto& node);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100140 void ParseReshape(const onnx::NodeProto& nodeProto);
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +0100141 void ParseUnsqueeze(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100142
Narumol Prangnawarat1112b012021-09-30 12:10:50 +0100143 void RegisterInputSlot(armnn::IConnectableLayer* layer,
144 const std::string& tensorId,
145 unsigned int slotIndex);
telsoa01c577f2c2018-08-31 09:22:23 +0100146 void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
147 void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
148
149 void SetupInputLayers();
150 void SetupOutputLayers();
151
152 void ResetParser();
153 void Cleanup();
154
Jan Eilers53ef7952021-06-02 12:01:25 +0100155 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
156 CreateConstTensor(const std::string name,
157 armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +0100158
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100159 std::pair<armnn::ConstTensor, std::unique_ptr<int32_t[]>>
160 CreateInt64ConstTensor(const std::string name,
161 armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());
162
telsoa01c577f2c2018-08-31 09:22:23 +0100163 template <typename TypeList, typename Location>
164 void ValidateInputs(const onnx::NodeProto& node,
165 TypeList validInputs,
166 const Location& location);
167
168 /// The network we're building. Gets cleared after it is passed to the user
169 armnn::INetworkPtr m_Network;
170
Ryan OShea337c17f2020-02-21 12:33:17 +0000171 /// Ptr to the graph we're building the network from
telsoa01c577f2c2018-08-31 09:22:23 +0100172 GraphPtr m_Graph;
173
Ryan OShea337c17f2020-02-21 12:33:17 +0000174 /// Map of the information for every tensor
telsoa01c577f2c2018-08-31 09:22:23 +0100175 struct OnnxTensor
176 {
177 std::unique_ptr<armnn::TensorInfo> m_info;
178 std::unique_ptr<const onnx::TensorProto> m_tensor;
179 onnx::TensorProto::DataType m_dtype;
180
181 OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
182 bool isConstant() { return m_tensor != nullptr; }
telsoa01c577f2c2018-08-31 09:22:23 +0100183 };
184
185 std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;
186
187 /// map of onnx operation names to parsing member functions
188 static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;
189
190 /// A mapping of an output slot to each of the input slots it should be connected to
191 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
192 /// The inputSlots are from the layers that use this tensor as one of their inputs
193 struct TensorSlots
194 {
195 armnn::IOutputSlot* outputSlot;
196 std::vector<armnn::IInputSlot*> inputSlots;
197
198 TensorSlots() : outputSlot(nullptr) { }
199 };
Ryan OShea337c17f2020-02-21 12:33:17 +0000200 /// Map of the tensor names to their connections for the connections of the layers of the graph
telsoa01c577f2c2018-08-31 09:22:23 +0100201 std::unordered_map<std::string, TensorSlots> m_TensorConnections;
202
Ryan OShea337c17f2020-02-21 12:33:17 +0000203 /// Map of the tensor names to their node and index in graph.node()
telsoa01c577f2c2018-08-31 09:22:23 +0100204 std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;
205
Teresa Charlinbc148812021-12-13 15:29:10 +0000206 /// Number of times a specific node (identified by its index number) was used as input
telsoa01c577f2c2018-08-31 09:22:23 +0100207 /// and list of the nodes it was fused with
208 struct UsageSummary
209 {
210 std::vector<size_t> fusedWithNodes;
211 size_t inputForNodes;
212
213 UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }
214
215 };
216
217 std::vector<UsageSummary> m_OutputsFusedAndUsed;
Ryan OSheaed27ee72020-04-22 16:37:29 +0100218
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +0100219 std::map<std::string, armnn::TensorShape> m_InputShapes;
220
221 std::unordered_map<std::string, armnn::TensorInfo> m_InputInfos;
222
223 std::unordered_map<std::string, armnn::TensorInfo> m_OutputInfos;
224
telsoa01c577f2c2018-08-31 09:22:23 +0100225};
226}