blob: 101e99ff8d40abe8ed6cfe8415453c0f6d4dab8a [file] [log] [blame]
telsoa01c577f2c2018-08-31 09:22:23 +01001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
David Beckecb56cd2018-09-05 12:52:57 +01003// SPDX-License-Identifier: MIT
telsoa01c577f2c2018-08-31 09:22:23 +01004//
5#pragma once
6
7#include "armnnOnnxParser/IOnnxParser.hpp"
8#include "google/protobuf/repeated_field.h"
9#include <unordered_map>
10
11#include <onnx/onnx.pb.h>
12
13
// Forward declarations of the armnn types that the parser declarations
// in this header mention by name only.
namespace armnn
{
enum class ActivationFunction;
class TensorInfo;
}
19
20namespace armnnOnnxParser
21{
22
telsoa01c577f2c2018-08-31 09:22:23 +010023using ModelPtr = std::unique_ptr<onnx::ModelProto>;
24
Kevin Mayef33cb12021-01-29 14:24:57 +000025class OnnxParserImpl
telsoa01c577f2c2018-08-31 09:22:23 +010026{
27
Kevin Mayef33cb12021-01-29 14:24:57 +000028using OperationParsingFunction = void(OnnxParserImpl::*)(const onnx::NodeProto& NodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +010029
30public:
31
32 using GraphPtr = std::unique_ptr<onnx::GraphProto>;
33
34 /// Create the network from a protobuf binary file on disk
Kevin Mayef33cb12021-01-29 14:24:57 +000035 armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010036
37 /// Create the network from a protobuf text file on disk
Kevin Mayef33cb12021-01-29 14:24:57 +000038 armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile);
telsoa01c577f2c2018-08-31 09:22:23 +010039
40 /// Create the network directly from protobuf text in a string. Useful for debugging/testing
Kevin Mayef33cb12021-01-29 14:24:57 +000041 armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText);
telsoa01c577f2c2018-08-31 09:22:23 +010042
43 /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
Kevin Mayef33cb12021-01-29 14:24:57 +000044 BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010045
46 /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
Kevin Mayef33cb12021-01-29 14:24:57 +000047 BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const;
telsoa01c577f2c2018-08-31 09:22:23 +010048
49public:
50
Kevin Mayef33cb12021-01-29 14:24:57 +000051 OnnxParserImpl();
52 ~OnnxParserImpl() = default;
telsoa01c577f2c2018-08-31 09:22:23 +010053
54 static ModelPtr LoadModelFromBinaryFile(const char * fileName);
55 static ModelPtr LoadModelFromTextFile(const char * fileName);
56 static ModelPtr LoadModelFromString(const std::string& inputString);
57
Ryan OShea337c17f2020-02-21 12:33:17 +000058 /// Retrieve inputs names
telsoa01c577f2c2018-08-31 09:22:23 +010059 static std::vector<std::string> GetInputs(ModelPtr& model);
60
Ryan OShea337c17f2020-02-21 12:33:17 +000061 /// Retrieve outputs names
telsoa01c577f2c2018-08-31 09:22:23 +010062 static std::vector<std::string> GetOutputs(ModelPtr& model);
63
Matthew Sloyanac001ee2021-02-03 10:43:04 +000064 /// Retrieve version in X.Y.Z form
65 static const std::string GetVersion();
66
telsoa01c577f2c2018-08-31 09:22:23 +010067private:
68
69 /// Parses a ModelProto loaded into memory from one of the other CreateNetwork*
70 armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);
71
Ryan OShea337c17f2020-02-21 12:33:17 +000072 /// Parse every node and make the connection between the resulting tensors
telsoa01c577f2c2018-08-31 09:22:23 +010073 void LoadGraph();
74
75 void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list);
76
77 std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames,
78 const armnn::IConnectableLayer* layer,
79 std::vector<armnn::TensorShape> inputShapes);
80
81 void DetectFullyConnected();
82
83 template <typename Location>
84 void GetInputAndParam(const onnx::NodeProto& node,
85 std::string* inputName,
86 std::string* constName,
87 const Location& location);
88
89 template <typename Location>
90 void To1DTensor(const std::string &name, const Location& location);
91
92 //Broadcast Preparation functions
93 std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
94 void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);
95
Ryan OSheaed27ee72020-04-22 16:37:29 +010096 void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
97 void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
98 void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
99
telsoa01c577f2c2018-08-31 09:22:23 +0100100 void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
101 void CreateReshapeLayer(const std::string& inputName,
102 const std::string& outputName,
103 const std::string& layerName);
104
Tee Jung7ff9a602019-11-01 07:04:42 +0000105 void ParseActivation(const onnx::NodeProto& nodeProto, const armnn::ActivationFunction func);
Finn Williams7ee5d2c2020-03-27 11:11:50 +0000106 void ParseClip(const onnx::NodeProto& nodeProto);
Tee Jung7ff9a602019-11-01 07:04:42 +0000107 void ParseSigmoid(const onnx::NodeProto& nodeProto);
108 void ParseTanh(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100109 void ParseRelu(const onnx::NodeProto& nodeProto);
Tee Jung7ff9a602019-11-01 07:04:42 +0000110 void ParseLeakyRelu(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100111
telsoa01c577f2c2018-08-31 09:22:23 +0100112 void ParseAdd(const onnx::NodeProto& nodeProto);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100113 void ParseAveragePool(const onnx::NodeProto& nodeProto);
114 void ParseBatchNormalization(const onnx::NodeProto& node);
115 void ParseConstant(const onnx::NodeProto& nodeProto);
116 void ParseConv(const onnx::NodeProto& nodeProto);
117 void ParseFlatten(const onnx::NodeProto& node);
118 void ParseGlobalAveragePool(const onnx::NodeProto& node);
119 void ParseMaxPool(const onnx::NodeProto& nodeProto);
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100120 void ParseShape(const onnx::NodeProto& node);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100121 void ParseReshape(const onnx::NodeProto& nodeProto);
telsoa01c577f2c2018-08-31 09:22:23 +0100122
123 void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
124 void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
125
126 void SetupInputLayers();
127 void SetupOutputLayers();
128
129 void ResetParser();
130 void Cleanup();
131
Jan Eilers53ef7952021-06-02 12:01:25 +0100132 std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
133 CreateConstTensor(const std::string name,
134 armnn::Optional<armnn::PermutationVector&> permutationVector = armnn::EmptyOptional());
telsoa01c577f2c2018-08-31 09:22:23 +0100135
136 template <typename TypeList, typename Location>
137 void ValidateInputs(const onnx::NodeProto& node,
138 TypeList validInputs,
139 const Location& location);
140
141 /// The network we're building. Gets cleared after it is passed to the user
142 armnn::INetworkPtr m_Network;
143
Ryan OShea337c17f2020-02-21 12:33:17 +0000144 /// Ptr to the graph we're building the network from
telsoa01c577f2c2018-08-31 09:22:23 +0100145 GraphPtr m_Graph;
146
Ryan OShea337c17f2020-02-21 12:33:17 +0000147 /// Map of the information for every tensor
telsoa01c577f2c2018-08-31 09:22:23 +0100148 struct OnnxTensor
149 {
150 std::unique_ptr<armnn::TensorInfo> m_info;
151 std::unique_ptr<const onnx::TensorProto> m_tensor;
152 onnx::TensorProto::DataType m_dtype;
153
154 OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
155 bool isConstant() { return m_tensor != nullptr; }
telsoa01c577f2c2018-08-31 09:22:23 +0100156 };
157
158 std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;
159
160 /// map of onnx operation names to parsing member functions
161 static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;
162
163 /// A mapping of an output slot to each of the input slots it should be connected to
164 /// The outputSlot is from the layer that creates this tensor as one of its ouputs
165 /// The inputSlots are from the layers that use this tensor as one of their inputs
166 struct TensorSlots
167 {
168 armnn::IOutputSlot* outputSlot;
169 std::vector<armnn::IInputSlot*> inputSlots;
170
171 TensorSlots() : outputSlot(nullptr) { }
172 };
Ryan OShea337c17f2020-02-21 12:33:17 +0000173 /// Map of the tensor names to their connections for the connections of the layers of the graph
telsoa01c577f2c2018-08-31 09:22:23 +0100174 std::unordered_map<std::string, TensorSlots> m_TensorConnections;
175
Ryan OShea337c17f2020-02-21 12:33:17 +0000176 /// Map of the tensor names to their node and index in graph.node()
telsoa01c577f2c2018-08-31 09:22:23 +0100177 std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;
178
179 /// Number of times a specific node (identified by his index number) was used as input
180 /// and list of the nodes it was fused with
181 struct UsageSummary
182 {
183 std::vector<size_t> fusedWithNodes;
184 size_t inputForNodes;
185
186 UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }
187
188 };
189
190 std::vector<UsageSummary> m_OutputsFusedAndUsed;
Ryan OSheaed27ee72020-04-22 16:37:29 +0100191
telsoa01c577f2c2018-08-31 09:22:23 +0100192};
193}