//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "armnnOnnxParser/IOnnxParser.hpp"
#include "google/protobuf/repeated_field.h"

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <onnx/onnx.pb.h>

namespace armnn
{
class TensorInfo;
}

namespace armnnOnnxParser
{

using ModelPtr = std::unique_ptr<onnx::ModelProto>;

class OnnxParser : public IOnnxParser
{

using OperationParsingFunction = void(OnnxParser::*)(const onnx::NodeProto& NodeProto);

public:

    using GraphPtr = std::unique_ptr<onnx::GraphProto>;

    /// Create the network from a protobuf binary file on disk
    virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;

    /// Create the network from a protobuf text file on disk
    virtual armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile) override;

    /// Create the network directly from protobuf text in a string. Useful for debugging/testing
    virtual armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText) override;

    /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
    virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;

    /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
    virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
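
    // Illustrative sketch of how the interface above is typically driven (not part of
    // this header). IOnnxParser::Create() and IOnnxParserPtr are declared in
    // armnnOnnxParser/IOnnxParser.hpp; the file path and the tensor names "input" and
    // "output" below are placeholder assumptions, not values defined by this class.
    //
    //     armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
    //     armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx");
    //     armnnOnnxParser::BindingPointInfo inputInfo  = parser->GetNetworkInputBindingInfo("input");
    //     armnnOnnxParser::BindingPointInfo outputInfo = parser->GetNetworkOutputBindingInfo("output");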

public:

    OnnxParser();

    static ModelPtr LoadModelFromBinaryFile(const char * fileName);
    static ModelPtr LoadModelFromTextFile(const char * fileName);
    static ModelPtr LoadModelFromString(const std::string& inputString);

    /// Retrieve input names
    static std::vector<std::string> GetInputs(ModelPtr& model);

    /// Retrieve output names
    static std::vector<std::string> GetOutputs(ModelPtr& model);
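
    // A minimal sketch of the static helpers above, e.g. for listing a model's I/O tensor
    // names without building a network ("model.onnx" is a placeholder path):
    //
    //     ModelPtr model = OnnxParser::LoadModelFromBinaryFile("model.onnx");
    //     std::vector<std::string> inputNames  = OnnxParser::GetInputs(model);
    //     std::vector<std::string> outputNames = OnnxParser::GetOutputs(model);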

private:

    /// Parses a ModelProto loaded into memory from one of the other CreateNetwork* methods
    armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);

    /// Parse every node and make the connections between the resulting tensors
    void LoadGraph();

    void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto>* list);

    std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames,
                                                     const armnn::IConnectableLayer* layer,
                                                     std::vector<armnn::TensorShape> inputShapes);

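    // Scans the graph for MatMul nodes (optionally followed by an Add) that can be fused
    // into a single FullyConnected layer, since ONNX has no dedicated FullyConnected
    // operator; the outcome of the scan is tracked in m_OutputsFusedAndUsed below.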
    void DetectFullyConnected();

    template <typename Location>
    void GetInputAndParam(const onnx::NodeProto& node,
                          std::string* inputName,
                          std::string* constName,
                          const Location& location);

    template <typename Location>
    void To1DTensor(const std::string& name, const Location& location);

    // Broadcast preparation functions
    std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
    void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);

    void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
    void CreateReshapeLayer(const std::string& inputName,
                            const std::string& outputName,
                            const std::string& layerName);

    void ParseBatchNormalization(const onnx::NodeProto& node);
    void ParseConstant(const onnx::NodeProto& nodeProto);

    void ParseMaxPool(const onnx::NodeProto& nodeProto);
    void ParseAveragePool(const onnx::NodeProto& nodeProto);
    void ParseGlobalAveragePool(const onnx::NodeProto& node);

    void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);

    void ParseReshape(const onnx::NodeProto& nodeProto);
    void ParseRelu(const onnx::NodeProto& nodeProto);

    void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
    void ParseConv(const onnx::NodeProto& nodeProto);

    void ParseAdd(const onnx::NodeProto& nodeProto);
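    // Emits one FullyConnected layer for a detected MatMul node and, when addNode is
    // non-null, folds the following Add (the bias) into the same layer.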
    void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);

    void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
    void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);

    void SetupInputLayers();
    void SetupOutputLayers();

    void ResetParser();
    void Cleanup();

    std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> CreateConstTensor(const std::string name);

    template <typename TypeList, typename Location>
    void ValidateInputs(const onnx::NodeProto& node,
                        TypeList validInputs,
                        const Location& location);

    /// The network we're building. Gets cleared after it is passed to the user
    armnn::INetworkPtr m_Network;

    /// Ptr to the graph we're building the network from
    GraphPtr m_Graph;

    /// Map of the information for every tensor
    struct OnnxTensor
    {
        std::unique_ptr<armnn::TensorInfo> m_info;
        std::unique_ptr<const onnx::TensorProto> m_tensor;
        onnx::TensorProto::DataType m_dtype;

        OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
        bool isConstant() { return m_tensor != nullptr; }

    };

    std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;

    /// Map of onnx operation names to parsing member functions
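    /// (defined in OnnxParser.cpp; illustrative entries would pair operator names with the
    /// Parse* members above, e.g. { "Conv", &OnnxParser::ParseConv } or
    /// { "Relu", &OnnxParser::ParseRelu })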
    static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;

    /// A mapping of an output slot to each of the input slots it should be connected to
    /// The outputSlot is from the layer that creates this tensor as one of its outputs
    /// The inputSlots are from the layers that use this tensor as one of their inputs
    struct TensorSlots
    {
        armnn::IOutputSlot* outputSlot;
        std::vector<armnn::IInputSlot*> inputSlots;

        TensorSlots() : outputSlot(nullptr) { }
    };
    /// Map of the tensor names to their connections, used to connect the layers of the graph
    std::unordered_map<std::string, TensorSlots> m_TensorConnections;

    /// Map of the tensor names to their node and index in graph.node()
    std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;

    /// Number of times a specific node (identified by its index number) was used as input,
    /// and the list of nodes it was fused with
    struct UsageSummary
    {
        std::vector<size_t> fusedWithNodes;
        size_t inputForNodes;

        UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }

    };

    std::vector<UsageSummary> m_OutputsFusedAndUsed;
};
}