//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "OnnxParser.hpp"

#include "armnnOnnxParser/Version.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <ParserHelper.hpp>
#include <VerificationHelpers.hpp>

#include <fmt/format.h>

#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>

#include <iostream>
#include <numeric>
#include <armnnUtils/Permute.hpp>

using namespace armnn;

namespace armnnOnnxParser
{

IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {}

IOnnxParser::~IOnnxParser() = default;

IOnnxParser* IOnnxParser::CreateRaw()
{
    return new IOnnxParser();
}

IOnnxParserPtr IOnnxParser::Create()
{
    return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy);
}

void IOnnxParser::Destroy(IOnnxParser* parser)
{
    delete parser;
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
{
    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(
    const char* graphFile,
    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile,
                                                          const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText, inputShapes);
}

BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
}

BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkOutputBindingInfo(name);
}

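// Illustrative end-to-end usage of the public IOnnxParser API above (a minimal
// sketch; the model file and tensor names are assumptions, not part of this file):
//
//     armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
//     armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx");
//     armnn::BindingPointInfo inputInfo  = parser->GetNetworkInputBindingInfo("input");
//     armnn::BindingPointInfo outputInfo = parser->GetNetworkOutputBindingInfo("output");
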
namespace
{
void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes,
                        const onnx::TensorProto::DataType actualValue,
                        const char* validExpr,
                        std::string nodeName,
                        std::string tensorName,
                        const armnn::CheckLocation& location)
{
    bool isValid = std::any_of(validInputTypes.begin(),
                               validInputTypes.end(),
                               [&actualValue](onnx::TensorProto::DataType x) { return x == actualValue; } );
    if (!isValid)
    {
        throw ParseException(
            fmt::format("Datatype {} is not valid for tensor '{}' of node '{}', not in {{{}}}. {}",
                        onnx::TensorProto::DataType_Name(actualValue),
                        tensorName,
                        nodeName,
                        validExpr,
                        location.AsString()));
    }
}

#define CHECK_VALID_DATATYPE(NODE, TENSOR, ACTUAL, ...) \
CheckValidDataType({__VA_ARGS__}, ACTUAL, #__VA_ARGS__, NODE, TENSOR, CHECK_LOCATION())

using StrTypeListPair = std::pair<const char*, std::initializer_list<onnx::TensorProto::DataType>>;
#define STR_LIST(...) StrTypeListPair(#__VA_ARGS__, {__VA_ARGS__})

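// Example expansion of the macros above (illustrative): the call
//     CHECK_VALID_DATATYPE(node.name(), tensorName, actual, onnx::TensorProto::FLOAT)
// becomes
//     CheckValidDataType({onnx::TensorProto::FLOAT}, actual,
//                        "onnx::TensorProto::FLOAT", node.name(), tensorName, CHECK_LOCATION());
// so the stringised __VA_ARGS__ doubles as the readable type list in the error message.
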
template <typename Callable>
void ReadMandatoryNodeAttributeImpl(const onnx::NodeProto& node,
                                    const std::string& attribName,
                                    onnx::AttributeProto::AttributeType expectedType,
                                    Callable callable)
{
    auto attribs = node.attribute();
    int attriNum = 0;
    while (attriNum < node.attribute_size())
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(fmt::format("Attribute {} of node {} expected to have {} as "
                                                 "onnx::AttributeProto::AttributeType, but found {} instead {}",
                                                 attribName,
                                                 node.name(),
                                                 onnx::AttributeProto::AttributeType_Name(expectedType),
                                                 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                                 CHECK_LOCATION().AsString()));
            }
            break;
        }
        ++attriNum;
    }
    if (attriNum == node.attribute_size())
    {
        throw ParseException(fmt::format("Could not find required attribute {} in node {} {}",
                                         attribName, node.name(), CHECK_LOCATION().AsString()));
    }
}

template <typename Callable>
void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
                                   const std::string& attribName,
                                   onnx::AttributeProto::AttributeType expectedType,
                                   Callable callable)
{
    auto attribs = node.attribute();
    for (int attriNum = 0; attriNum < node.attribute_size(); ++attriNum)
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(
                    fmt::format("Attribute {} of node {} expected to have {} as onnx::AttributeProto::AttributeType, "
                                "but found {} instead {}",
                                attribName,
                                node.name(),
                                onnx::AttributeProto::AttributeType_Name(expectedType),
                                onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
}

int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node,
                                  const std::string& name)
{
    int attribValue = 0;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                   [&attribValue](const onnx::AttributeProto& attrValue)
                                   {
                                       attribValue = CHECKED_INT32(attrValue.i());
                                   });
    return attribValue;
}

int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
                                       const std::string& name,
                                       const int64_t defaultValue = 0)
{
    int64_t attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.i();
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                           const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                   [&attriList](const onnx::AttributeProto& attrValue)
                                   {
                                       for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                       {
                                           attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                       }
                                   });
    return attriList;
}

uint32_t ReadOptionalNodeUint32Attribute(const onnx::NodeProto& node,
                                         const std::string& name,
                                         const uint32_t defaultVal = 0u)
{
    uint32_t attribValue = defaultVal;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = CHECKED_NON_NEGATIVE(CHECKED_INT32((attrValue.i())));
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                          const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                  [&attriList](const onnx::AttributeProto& attrValue)
                                  {
                                      for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                      {
                                          attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                      }
                                  });

    return attriList;
}

float ReadOptionalNodeFloatAttribute(const onnx::NodeProto& node,
                                     const std::string& name,
                                     const float defaultValue = 0.0f)
{
    float attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::FLOAT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.f();
                                  });
    return attribValue;
}

std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const std::string& name)
{
    std::string attribValue = "";
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::STRING,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.s();
                                  });
    return attribValue;
}

armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
    DataType type;
    switch(data_type)
    {
        case onnx::TensorProto::FLOAT:
        {
            type = DataType::Float32;
            break;
        }
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64:
        {
            type = DataType::Signed32;
            break;
        }
        default:
        {
            throw ParseException(
                fmt::format("'{}' is not a currently supported datatype for tensor {}."
                            " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                            onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
                            name,
                            CHECK_LOCATION().AsString() ));
        }
    }

    // Scalar Tensor
    if (shape.empty())
    {
        return TensorInfo(TensorShape(Dimensionality::Scalar), type);
    }

    // Dynamic Tensor
    if(std::find(shape.begin(), shape.end(), 0) != shape.end())
    {
        return TensorInfo(TensorShape(Dimensionality::NotSpecified), type);
    }

    return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}

armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
{
    const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
    std::vector<unsigned int> shapeDims;
    for (int i = 0; i < onnxShape.dim_size(); ++i)
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
    }

    return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}

armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
{
    std::vector<unsigned int> shapeDims;

    for (auto dim: tensor.dims())
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
    }

    return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}

std::string TensorInfoAsString(const TensorInfo& info,
                               const std::string& name,
                               const onnx::TensorProto::DataType& type)
{
    const TensorShape shape = info.GetShape();
    std::stringstream ss;
    ss << "tensor '" << name << "' contains "
       << onnx::TensorProto::DataType_Name(type)
       << " and has shape [";

    for (uint32_t i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        ss << shape[i] << ", ";
    }
    ss << shape[shape.GetNumDimensions() - 1] << "]";
    return ss.str();
}

void CalcPadding(uint32_t inputSize,
                 uint32_t filterSize,
                 uint32_t stride,
                 uint32_t dilation,
                 uint32_t* paddingFront,
                 uint32_t* paddingBack,
                 bool isUpper)
{
    uint32_t outputSize = (inputSize + stride - 1) / stride;
    uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
    uint32_t temp = (outputSize - 1) * stride + dilatedSize;
    *paddingFront = (temp - inputSize) / 2;
    *paddingBack = *paddingFront;
    if((temp - inputSize) % 2 == 1)
    {
        if (isUpper)
        {
            *paddingBack += 1;
        }
        else
        {
            *paddingFront += 1;
        }
    }
}
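
// Worked example of the SAME-padding arithmetic above: with inputSize = 4,
// filterSize = 3, stride = 2 and dilation = 1 we get
// outputSize = (4 + 2 - 1) / 2 = 2, dilatedSize = 3, temp = (2 - 1) * 2 + 3 = 5,
// so one element of padding is required; isUpper == true (SAME_UPPER) puts it
// in paddingBack, otherwise it goes in paddingFront.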

TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                              const TensorShape& inShape,
                              const std::string& outName,
                              DataType dataType = DataType::Float32)
{
    std::vector<int> targetDims;
    for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
    {
        int val = CHECKED_INT32(targetShapeTensor[i]);
        if(val == 0)
        {
            targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
        }
        else
        {
            targetDims.push_back(val);
        }
    }

    std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
    const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
    if (stretchDim != targetDims.end())
    {
        if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
        {
            std::stringstream ss;
            ss << "[ ";
            for(uint i = 0; i < targetDims.size() - 1; ++i)
            {
                ss << targetDims[i] << ", ";
            }
            ss << targetDims[targetDims.size() - 1] << " ]";

            throw ParseException(
                fmt::format("Error during creation of reshaped tensor '{}'. At most one component of shape can be "
                            "-1 and here, shape is {} {}",
                            outName,
                            ss.str(),
                            CHECK_LOCATION().AsString()));
        }

        auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(),
                                                                                   targetDims.end(),
                                                                                   -1,
                                                                                   std::multiplies<int32_t>()));
        auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
        outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
    }
    TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
    return TensorInfo(outShape, dataType);
}
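
// Example of the semantics above (illustrative): with inShape [1,2,3,4]
// (24 elements) and a target shape of [0,-1], the 0 copies the corresponding
// input dimension (1) and the single -1 is stretched to match the element
// count, giving an output shape of [1,24].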

} // namespace

const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = {
    { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization},
    { "GlobalAveragePool",  &OnnxParserImpl::ParseGlobalAveragePool},
    { "AveragePool",        &OnnxParserImpl::ParseAveragePool },
    { "Clip",               &OnnxParserImpl::ParseClip },
    { "Constant",           &OnnxParserImpl::ParseConstant },
    { "MaxPool",            &OnnxParserImpl::ParseMaxPool },
    { "Reshape",            &OnnxParserImpl::ParseReshape },
    { "Sigmoid",            &OnnxParserImpl::ParseSigmoid },
    { "Tanh",               &OnnxParserImpl::ParseTanh },
    { "Relu",               &OnnxParserImpl::ParseRelu },
    { "LeakyRelu",          &OnnxParserImpl::ParseLeakyRelu },
    { "Conv",               &OnnxParserImpl::ParseConv },
    { "Add",                &OnnxParserImpl::ParseAdd },
    { "Flatten",            &OnnxParserImpl::ParseFlatten },
    { "Shape",              &OnnxParserImpl::ParseShape },
    { "Gather",             &OnnxParserImpl::ParseGather },
    { "Unsqueeze",          &OnnxParserImpl::ParseUnsqueeze },
    { "Concat",             &OnnxParserImpl::ParseConcat },
    { "Gemm",               &OnnxParserImpl::ParseGemm }
};
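
// Dispatch is keyed on the ONNX op_type string: LoadGraph looks the operation
// up in this map and invokes the handler as (this->*func)(node). MatMul and
// Add are special-cased in LoadGraph so that MatMul+Add pairs can be fused
// into a single FullyConnected layer.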

template<typename TypePair, typename Location>
void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
                                    TypePair validInputs,
                                    const Location& location)
{
    for(auto input : node.input())
    {
        CheckValidDataType(validInputs.second,
                           m_TensorsInfo[input].m_dtype,
                           validInputs.first,
                           node.name(),
                           input,
                           location);
    }
}

#define VALID_INPUTS(NODE, VALID_INPUTS) \
    OnnxParserImpl::ValidateInputs(NODE, \
                                   VALID_INPUTS, \
                                   CHECK_LOCATION())

std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
                                                          const IConnectableLayer* layer,
                                                          std::vector<TensorShape> inputShapes,
                                                          const onnx::TensorProto::DataType& dataType)
{
    ARMNN_ASSERT(! outNames.empty());
    bool needCompute = std::any_of(outNames.begin(),
                                   outNames.end(),
                                   [this](std::string name)
                                   {
                                       return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr
                                               || m_TensorsInfo[name].m_info->GetShape().GetDimensionality() ==
                                                  Dimensionality::NotSpecified);
                                   });
    std::vector<TensorInfo> outInfo;
    // if the output info(s) are not here, we need to compute them
    std::vector<TensorShape> inferredShapes;
    DataType armnnType = DataType::Float32;
    if(needCompute)
    {
        inferredShapes = layer->InferOutputShapes(inputShapes);
        ARMNN_ASSERT(inferredShapes.size() == outNames.size());
        switch (dataType)
        {
            case onnx::TensorProto::FLOAT:
            {
                armnnType = DataType::Float32;
                break;
            }
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::INT64:
            {
                armnnType = DataType::Signed32;
                break;
            }
            default:
            {
                throw ParseException(
                    fmt::format("'{}' is not a currently supported datatype for {}."
                                " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                                onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
                                layer->GetName(),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
    for (uint i = 0; i < outNames.size(); ++i)
    {
        if(needCompute)
        {
            m_TensorsInfo[outNames[i]] = OnnxTensor();
            m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
                TensorInfo(inferredShapes[i], armnnType));
            m_TensorsInfo[outNames[i]].m_dtype = dataType;
        }
        outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
    }
    return outInfo;
}
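
// Note on the helper above: when an output's TensorInfo is missing or its
// shape is NotSpecified, the shape is inferred from the layer itself via
// InferOutputShapes and cached in m_TensorsInfo so downstream nodes can use it.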

OnnxParserImpl::OnnxParserImpl()
    : m_Network(nullptr, nullptr)
{
}

void OnnxParserImpl::ResetParser()
{
    m_Network = armnn::INetworkPtr(nullptr, nullptr);
    m_Graph = nullptr;
    m_InputInfos.clear();
    m_OutputInfos.clear();
}

void OnnxParserImpl::Cleanup()
{
    m_TensorConnections.clear();
    m_TensorsInfo.clear();
    m_OutputsMap.clear();
    m_OutputsFusedAndUsed.clear();
    m_InputShapes.clear();
}

template<typename T>
std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
CreateConstTensorImpl(const T* bufferPtr,
                      armnn::TensorInfo& tensorInfo,
                      const armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());

    std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);

    if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
    {
        tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
        armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
                            reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
    }
    else
    {
        ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
    }

    return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
}
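
// Illustrative call (a sketch; srcPtr/info are placeholders, not names used
// elsewhere in this file):
//     armnn::PermutationVector perm{ 3, 0, 1, 2 };
//     auto tensorAndData = CreateConstTensorImpl<float>(srcPtr, info,
//                              armnn::Optional<armnn::PermutationVector&>(perm));
// The returned pair keeps the (possibly permuted) backing buffer alive
// alongside the ConstTensor that views it, so callers must hold on to both.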

std::pair<ConstTensor, std::unique_ptr<float[]>>
OnnxParserImpl::CreateConstTensor(const std::string name,
                                  armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    // ONNX can have Float16 and double constant nodes but ArmNN only supports float32
    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();

    // Const tensors require at least a list of values
    if (tensorInfo.GetNumElements() == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    auto srcData = onnxTensor.float_data().data();
    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.float_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }
        return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
    }
    else
    {
        return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
                                            tensorInfo,
                                            permutationVector);
    }
}

std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
                                       armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();
    uint numElements = tensorInfo.GetNumElements();

    // Const tensors require at least a list of values
    if (numElements == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        auto srcData = onnxTensor.int64_data().data();
        if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.int64_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }

        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }

        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
    else
    {
        auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }
        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
}

ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "r");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    using google::protobuf::io::FileInputStream;
    std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
    bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile,
                                                      const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                                    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent)
{
    if (binaryContent.size() == 0)
    {
        throw ParseException(fmt::format("Missing binary content. {}", CHECK_LOCATION().AsString()));
    }
    // Parse the buffer into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size()));
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}
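
// Illustrative caller for the in-memory overloads above (a sketch; the file
// name and the 'parser' instance are assumptions):
//     std::ifstream file("model.onnx", std::ios::binary);
//     std::vector<uint8_t> content((std::istreambuf_iterator<char>(file)),
//                                  std::istreambuf_iterator<char>());
//     armnn::INetworkPtr network = parser->CreateNetworkFromBinary(content);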

ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "rb");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::FileInputStream inStream(fileno(fd));
    google::protobuf::io::CodedInputStream codedStream(&inStream);
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
{
    if (protoText == "")
    {
        throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
                                                   CHECK_LOCATION().AsString()));
    }
    // Parse the string into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText,
                                                    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
{
    m_Network = INetwork::Create();
    try
    {
        m_Graph = std::make_unique<onnx::GraphProto>(*model.mutable_graph());
        LoadGraph();
    }
    catch (const ParseException& e)
    {
        Cleanup();
        throw e;
    }
    Cleanup();
    return std::move(m_Network);
}

void OnnxParserImpl::LoadGraph()
{
    ARMNN_ASSERT(m_Graph.get() != nullptr);

    // Fill m_TensorsInfo with the shapes and value of every tensor
    SetupInfo(m_Graph->mutable_output());
    SetupInfo(m_Graph->mutable_input());
    SetupInfo(m_Graph->mutable_value_info());

    for (auto tensor : m_Graph->initializer())
    {
        m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.data_type());
    }

    SetupInputLayers();
    SetupOutputLayers();

    // Detect FullyConnected layers with bias and update the FusedAndUsed map accordingly
    DetectFullyConnected();

    // Parsing the graph
    for(size_t nodeIndex = 0; nodeIndex < static_cast<size_t>(m_Graph->node_size()); nodeIndex++)
    {
        auto node = m_Graph->node(static_cast<int>(nodeIndex));
        const std::string& operation = node.op_type();

        // check which layers we handled already (add and matmul fused as FC)
        if (operation == "MatMul" )
        {
            if(m_OutputsFusedAndUsed[nodeIndex].inputForNodes != m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.size())
            {
                // Node which can not be fused as a FullyConnected layer (used in layers as a simple matmul output)
                AddFullyConnected(node);
            }
        }
        else if (!(m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) && operation == "Add")
        {
            int matmulIndex = static_cast<int> (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes[0]);
            AddFullyConnected(m_Graph->node(matmulIndex), &node);
        }
        else if (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) // node is not part of a fused layer
        {
            auto it = m_ParserFunctions.find(operation);
            if (it != m_ParserFunctions.end())
            {
                auto func = it->second;
                (this->*func)(node);
            }
            else
            {
                throw ParseException(fmt::format("Unsupported operation {} for node '{}' {}",
                                                 operation,
                                                 node.name(),
                                                 CHECK_LOCATION().AsString()));
            }
        }
    }

    // Making the connections between the outputs and inputs of each layer
    for (const auto& tensorCon : m_TensorConnections)
    {
        if (tensorCon.second.outputSlot != nullptr)
        {
            for (size_t inputSlotIdx = 0; inputSlotIdx < tensorCon.second.inputSlots.size(); ++inputSlotIdx)
            {
                tensorCon.second.outputSlot->Connect(*(tensorCon.second.inputSlots[inputSlotIdx]));
            }
        }
    }

    // Get output info.
    for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
    {
        auto output = m_Graph->output(outputIndex);
        m_OutputInfos[output.name()] = *m_TensorsInfo[output.name()].m_info;
    }
}

void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
{
    for (auto tensor : *list)
    {
        m_TensorsInfo[tensor.name()] = OnnxTensor();
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.type().tensor_type().elem_type());
    }
}

void OnnxParserImpl::DetectFullyConnected()
{
    m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary());
    auto matmulAndConstant = [&](const std::string& constInput,
                                 const std::string& matmulInput,
                                 int& nodeIndex)
    {
        auto matmulIt = m_OutputsMap.find(matmulInput);
        if(matmulIt != m_OutputsMap.end() && matmulIt->second.first->op_type() == "MatMul"
            && m_TensorsInfo[constInput].isConstant())
        {
            nodeIndex = matmulIt->second.second;
            return true;
        }
        return false;
    };

    for(int nodeIndex = 0; nodeIndex < m_Graph->node_size(); nodeIndex++)
    {
        const onnx::NodeProto* node = &m_Graph->node(nodeIndex);
        for (const std::string& output : node->output())
        {
            m_OutputsMap[output] = std::make_pair(node, nodeIndex);
        }

        for (const std::string& input : node->input()) // count how many times a node is used as input
        {
            auto matmulIt = m_OutputsMap.find(input);
            if(matmulIt != m_OutputsMap.end())
            {
                ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes; // node used
            }
        }

        if (node->op_type() == "Add")
        {
            int matmulIndex = 0;
            if (matmulAndConstant(node->input(0), node->input(1), matmulIndex) ||
                matmulAndConstant(node->input(1), node->input(0), matmulIndex))
            {
                // matmul and add were fused
                m_OutputsFusedAndUsed[static_cast<size_t>(matmulIndex)].fusedWithNodes
                                                                      .push_back(static_cast<size_t>(nodeIndex));

                m_OutputsFusedAndUsed[static_cast<size_t>(nodeIndex)].fusedWithNodes
                                                                     .push_back(static_cast<size_t>(matmulIndex));
            }
        }
    }

    for (auto output: m_Graph->output()) // Add usages as output of the graph in count of usages
    {
        auto matmulIt = m_OutputsMap.find(output.name());
        if(matmulIt != m_OutputsMap.end())
        {
            ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes;
        }
    }
}
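
// Fusion pattern recognised above (illustrative): the subgraph
//     Y = MatMul(X, W)   with W constant
//     Z = Add(Y, B)      with B constant
// is recorded so that AddFullyConnected can later emit a single FullyConnected
// layer (weights W, bias B) instead of separate MatMul and Add layers,
// provided the MatMul output feeds nothing except that Add.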

template<typename Location>
void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node,
                                      std::string* inputName,
                                      std::string* constName,
                                      const Location& location)
{
    int cstIndex;
    if (m_TensorsInfo[node.input(0)].isConstant())
    {
        cstIndex = 0;
    }
    else if (m_TensorsInfo[node.input(1)].isConstant())
    {
        cstIndex = 1;
    }
    else
    {
        throw ParseException(fmt::format("One of the input tensors ('{}' or '{}') should be constant in node '{}' {}",
                                         node.input(0),
                                         node.input(1),
                                         node.name(),
                                         location.AsString()));
    }
    if(constName)
    {
        *constName = node.input(cstIndex);
    }
    if(inputName)
    {
        *inputName = node.input(!cstIndex);
    }
}

template<typename Location>
void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location)
{
    TensorShape shape = m_TensorsInfo[name].m_info->GetShape();
    std::vector<uint32_t> newShape;
    for(uint i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        if(shape[i] != 1)
        {
            throw ParseException(
                fmt::format("Only tensors with shape [1, ..., 1, X] can be converted to 1D and {} {}",
                            TensorInfoAsString(*m_TensorsInfo[name].m_info, name, m_TensorsInfo[name].m_dtype),
                            location.AsString()));
        }
    }
    newShape.push_back(shape[shape.GetNumDimensions() - 1]);

    m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
}
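
// Example: a bias stored as [1, 1, 256] is flattened to [256]; any leading
// dimension other than 1 triggers the ParseException above.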

void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
{
    ARMNN_ASSERT(node.op_type() == "Conv");

    DepthwiseConvolution2dDescriptor desc;
    desc.m_PadLeft = convDesc.m_PadLeft;
    desc.m_PadRight = convDesc.m_PadRight;
    desc.m_PadTop = convDesc.m_PadTop;
    desc.m_PadBottom = convDesc.m_PadBottom;
    desc.m_StrideX = convDesc.m_StrideX;
    desc.m_StrideY = convDesc.m_StrideY;
    desc.m_BiasEnabled = convDesc.m_BiasEnabled;

    armnn::IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, node.name().c_str());
    std::string permuteStr = "permute_" + node.input(1);
    std::vector<std::string> tensorIndexes = {node.input(0), permuteStr};

    auto weightTensor = CreateConstTensor(node.input(1));
    IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);

    // weights come in as [O,1,H,W] from ONNX and need to be converted to Arm NN's depthwise weights layout [1,H,W,O]
    armnn::PermutationVector perVec {3, 0, 1, 2};
    TensorInfo weightsPermuted = armnnUtils::Permuted(weightTensor.first.GetInfo(), perVec);

    // Inserts NewLayer so layers don't need to be re-sorted.
    IConnectableLayer* permuteLayer = m_Network->AddPermuteLayer(PermuteDescriptor(perVec),
                                                                 "permute_layer");
    permuteLayer->GetOutputSlot(0).SetTensorInfo(weightsPermuted);
    permuteLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));

    weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
    weightsLayer->GetOutputSlot(0).Connect(permuteLayer->GetInputSlot(0u));

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }

        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        tensorIndexes.emplace_back(node.input(2));

        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(biasTensor.first);
        biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.first.GetInfo());
        biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
    }

    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          weightsPermuted.GetShape() });

    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, tensorIndexes);

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}
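
// Permutation check for the depthwise weight conversion above: with
// perVec {3, 0, 1, 2}, source dimension 0 (O) moves to destination dimension 3
// while dimensions 1, 2, 3 (1, H, W) shift down to 0, 1, 2, so [O,1,H,W]
// becomes [1,H,W,O] as Arm NN's depthwise convolution expects.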

void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
{
    // find matmul inputs
    std::string weightName;
    std::string inputName;
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
    VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));

    GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());

    FullyConnectedDescriptor desc;
    desc.m_BiasEnabled = addNode != nullptr;

    IConnectableLayer* layer = nullptr;
    if(desc.m_BiasEnabled)
    {
        // find bias const
        std::string biasName;
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
        VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));

        GetInputAndParam(*addNode, nullptr, &biasName, CHECK_LOCATION());

        // Output shape is [1, weights[1]] and 1d vec in ONNX can be [1,X] so we convert biases to "armnn" 1D
        To1DTensor(biasName, CHECK_LOCATION());
        TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
        TensorInfo biasInfo = *m_TensorsInfo[biasName].m_info;

        if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
        {
            throw ParseException(
                fmt::format("Shape of weights '{}' and bias of following Add node '{}' do not match : {}"
                            " and {} ( /!\\ bias should be a 1D tensor) {}",
                            weightName,
                            addNode->name(),
                            TensorInfoAsString(*m_TensorsInfo[weightName].m_info, weightName,
                                               m_TensorsInfo[weightName].m_dtype),
                            TensorInfoAsString(*m_TensorsInfo[biasName].m_info, biasName,
                                               m_TensorsInfo[biasName].m_dtype),
                            CHECK_LOCATION().AsString()));
        }

        // Just add a FullyConnected layer, weights and biases are handled as inputs now.
        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({addNode->output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});
        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // Add constant layers to store weights/biases and connect them to the FullyConnected layer.
        if(m_TensorsInfo[weightName].isConstant())
        {
            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);

            weightInfo.SetConstant();
            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
        }

        if(m_TensorsInfo[biasName].isConstant())
        {
            IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);

            biasInfo.SetConstant();
            biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
        }

        RegisterInputSlots(layer, {inputName, weightName, biasName});
        RegisterOutputSlots(layer, {addNode->output(0)});
    }
    else
    {
        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({matmulNode.output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});
        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // Add constant layer to store weights and connect to FullyConnected layer.
        if(m_TensorsInfo[weightName].isConstant())
        {
            TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);

            weightInfo.SetConstant();
            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
        }

        RegisterInputSlots(layer, {inputName, weightName});
        RegisterOutputSlots(layer, {matmulNode.output(0)});
    }
}

void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    std::vector<uint32_t> kernel_shape = ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape"); // size of pool win
    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");

    desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
    desc.m_PoolWidth = kernel_shape[1];
    desc.m_PoolHeight = kernel_shape[0];

    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    // Check new padding version first
    if(pads.empty())
    {
        // Check deprecated version
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if( paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if (paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(fmt::format("Invalid auto_pad attribute for node {}. "
                                                 "Only SAME_UPPER, SAME_LOWER or VALID supported and found {} {}",
                                                 node.name(),
                                                 paddingString,
                                                 CHECK_LOCATION().AsString()));
            }
            auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth = inputInfo.GetShape()[3];
            CalcPadding(inputHeight,
                        desc.m_PoolHeight,
                        desc.m_StrideY,
                        1u,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        desc.m_PoolWidth,
                        desc.m_StrideX,
                        1u,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop = pads[0];
        desc.m_PadLeft = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight = pads[3];
    }

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}
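
// Note on the explicit 'pads' attribute above: for 2D NCHW pooling, ONNX orders
// the values as [x1_begin, x2_begin, x1_end, x2_end], i.e. [top, left, bottom,
// right], which is how they are unpacked into the descriptor.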

std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0,
                                                                        const std::string& input1)
{
    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input1);
        PrependForBroadcast(outputName, input1, input0);
        inputs.second = outputName;
    }
    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input0);
        PrependForBroadcast(outputName, input0, input1);
        inputs.first = outputName;
    }
    return inputs;
}
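
// Broadcast example (illustrative, assuming PrependForBroadcast pads with
// leading 1s as its name suggests): adding shapes [1,3,224,224] and [3,1,1]
// reshapes the lower-rank input to [1,3,1,1] so both operands have equal rank
// before the elementwise layer is wired up.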

void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateInt64ConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

Kevin Mayef33cb12021-01-29 14:24:57 +00001382void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName,
1383 const std::string& outputName,
1384 const std::string& layerName)
telsoa01c577f2c2018-08-31 09:22:23 +01001385{
1386 const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info;
1387 ReshapeDescriptor reshapeDesc;
1388 reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
1389
1390 IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001391 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001392 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1393
1394 // register the input connection slots for the layer, connections are made after all layers have been created
1395 // only the tensors for the inputs are relevant, exclude the const tensors
1396 RegisterInputSlots(layer, {inputName});
1397
1398 // register the output connection slots for the layer, connections are made after all layers have been created
1399 RegisterOutputSlots(layer, {outputName});
1400}
1401
Kevin Mayef33cb12021-01-29 14:24:57 +00001402void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
telsoa01c577f2c2018-08-31 09:22:23 +01001403{
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001404 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
telsoa01c577f2c2018-08-31 09:22:23 +01001405 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1406
1407 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1408
1409 ActivationDescriptor desc;
Tee Jung7ff9a602019-11-01 07:04:42 +00001410 desc.m_Function = func;
telsoa01c577f2c2018-08-31 09:22:23 +01001411
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001412 if (func == ActivationFunction::BoundedReLu)
1413 {
Narumol Prangnawaratf106ab72021-09-15 17:30:37 +01001414 if (node.input_size() == 1 && node.attribute_size() > 0)
1415 {
1416 desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
1417 desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
1418 }
1419 else
1420 {
1421 desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
1422 desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
1423 }
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001424 }
1425
telsoa01c577f2c2018-08-31 09:22:23 +01001426 IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001427 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001428
1429 auto outputInfo = ComputeOutputInfo({ node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1430 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1431
1432 // register the input connection slots for the layer, connections are made after all layers have been created
1433 // only the tensors for the inputs are relevant, exclude the const tensors
1434 RegisterInputSlots(layer, {node.input(0)});
1435
1436 // register the output connection slots for the layer, connections are made after all layers have been created
1437 RegisterOutputSlots(layer, {node.output(0)});
1438}
1439
Kevin Mayef33cb12021-01-29 14:24:57 +00001440void OnnxParserImpl::ParseClip(const onnx::NodeProto& node)
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001441{
1442 ParseActivation(node, ActivationFunction::BoundedReLu);
1443}
1444
Kevin Mayef33cb12021-01-29 14:24:57 +00001445void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001446{
1447 ParseActivation(node, ActivationFunction::Sigmoid);
1448}
1449
Kevin Mayef33cb12021-01-29 14:24:57 +00001450void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001451{
1452 ParseActivation(node, ActivationFunction::TanH);
1453}
1454
Kevin Mayef33cb12021-01-29 14:24:57 +00001455void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001456{
1457 ParseActivation(node, ActivationFunction::ReLu);
1458}
1459
Kevin Mayef33cb12021-01-29 14:24:57 +00001460void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001461{
1462 ParseActivation(node, ActivationFunction::LeakyReLu);
1463}
telsoa01c577f2c2018-08-31 09:22:23 +01001464
Kevin Mayef33cb12021-01-29 14:24:57 +00001465void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001466{
Ryan OSheaed27ee72020-04-22 16:37:29 +01001467 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1468 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
telsoa01c577f2c2018-08-31 09:22:23 +01001469
Ryan OSheaed27ee72020-04-22 16:37:29 +01001470 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
telsoa01c577f2c2018-08-31 09:22:23 +01001471
Ryan OSheaed27ee72020-04-22 16:37:29 +01001472 // TODO: unify broadcast validation code across layers
1473 // tracked by: IVGCVSW-1576
telsoa01c577f2c2018-08-31 09:22:23 +01001474
Ryan OSheaed27ee72020-04-22 16:37:29 +01001475 // Checking broadcast compatibility: only scalar or 1D tensors
1476 auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
1477 auto input0 = *m_TensorsInfo[inputs.first].m_info;
1478 auto input1 = *m_TensorsInfo[inputs.second].m_info;
1479 ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
1480
1481 unsigned int numDims = input0.GetNumDimensions();
1482 for (unsigned int i = 0; i < numDims; i++)
telsoa01c577f2c2018-08-31 09:22:23 +01001483 {
Ryan OSheaed27ee72020-04-22 16:37:29 +01001484 unsigned int dim0 = input0.GetShape()[i];
1485 unsigned int dim1 = input1.GetShape()[i];
1486 if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
telsoa01c577f2c2018-08-31 09:22:23 +01001487 {
James Ward58dec6b2020-09-11 17:32:44 +01001488 throw ParseException(
1489 fmt::format("Broadcast is only supported for scalar or 1D tensors in Add node '{}'. "
1490 "Input dimensions should either match or one should be of size 1 and here, "
1491 "{} and {} {}",
1492 node.name(),
1493 TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
1494 m_TensorsInfo[inputs.first].m_dtype),
1495 TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
1496 m_TensorsInfo[inputs.second].m_dtype),
1497 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001498 }
telsoa01c577f2c2018-08-31 09:22:23 +01001499 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001500
1501
1502 IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001503 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001504
1505 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
Ryan OSheaed27ee72020-04-22 16:37:29 +01001506 { m_TensorsInfo[inputs.first].m_info->GetShape(),
1507 m_TensorsInfo[inputs.second].m_info->GetShape() });
telsoa01c577f2c2018-08-31 09:22:23 +01001508 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1509
Ryan OSheaed27ee72020-04-22 16:37:29 +01001510 // register the input connections -> for constant inputs, we need to create a dedicated constant layer
1511 if(m_TensorsInfo[inputs.first].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001512 CreateConstantLayer(inputs.first, fmt::format("Add:constant_of_{}", node.input(0)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001513 }
1514 if(m_TensorsInfo[inputs.second].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001515 CreateConstantLayer(inputs.second, fmt::format("Add:constant_of_{}", node.input(1)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001516 }
1517 RegisterInputSlots(layer, {inputs.first, inputs.second});
telsoa01c577f2c2018-08-31 09:22:23 +01001518
Ryan OSheaed27ee72020-04-22 16:37:29 +01001519 // register the output connection
telsoa01c577f2c2018-08-31 09:22:23 +01001520 RegisterOutputSlots(layer, {node.output(0)});
1521}
1522
Kevin Mayef33cb12021-01-29 14:24:57 +00001523void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001524{
1525 Pooling2dDescriptor desc;
1526 desc.m_PoolType = PoolingAlgorithm::Average;
1527
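    // ONNX count_include_pad == 1 asks for padded zeroes to be counted in the averaging divisor;
    // PaddingMethod::IgnoreValue is Arm NN's equivalent (the default, Exclude, leaves padding out).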
 1528 uint32_t count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
1530 if(count_include_pad) {
1531 desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
1532 }
1533 AddPoolingLayer(node, desc);
1534}
1535
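// For reference, inference-time batch normalization computes
// y = scale * (x - mean) / sqrt(var + epsilon) + bias,
// which is why scale, bias, mean and var must all be constant tensors here.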
Kevin Mayef33cb12021-01-29 14:24:57 +00001536void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001537{
 1538 // Ignore the momentum and spatial parameters
1539
1540 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
1541 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1542
1543 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1544 for(int ind = 1; ind < node.input_size(); ++ind)
1545 {
1546 auto tensor = node.input(ind);
 1547 if(!m_TensorsInfo[tensor].isConstant())
1548 {
James Ward58dec6b2020-09-11 17:32:44 +01001549 throw ParseException(
1550 fmt::format("Input tensor '{}' should be constant in BatchNormalization node '{}' {}",
1551 tensor,
1552 node.name(),
1553 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001554 }
1555 }
1556
1557 float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
1558 BatchNormalizationDescriptor desc;
1559 desc.m_Eps = epsilon;
1560
1561 auto scaleTensor = CreateConstTensor(node.input(1));
1562 auto biasTensor = CreateConstTensor(node.input(2));
1563 auto meanTensor = CreateConstTensor(node.input(3));
1564 auto varTensor = CreateConstTensor(node.input(4));
1565
1566 IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
1567 meanTensor.first,
1568 varTensor.first,
1569 biasTensor.first,
1570 scaleTensor.first,
1571 node.name().c_str());
1572 ARMNN_ASSERT(layer != nullptr);
1573
1574 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1575 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1576
1577 RegisterInputSlots(layer, {node.input(0)}); //don't register constant inputs
1578
1579 // register the output connection
1580 RegisterOutputSlots(layer, {node.output(0)});
1581}
1582
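// A negative ONNX concat axis counts from the back, e.g. axis = -1 on rank-4 inputs
// normalises to 3 via the (rank + axis) % rank computation below.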
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +01001583void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
1584{
1585 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1586
1587 uint32_t numConcatView = static_cast<uint32_t>(node.input_size());
1588 uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions();
1589
1590 int axisInt = ReadMandatoryNodeIntAttribute(node, "axis");
1591
1592 unsigned int concatDimInput = static_cast<unsigned int>(
1593 (static_cast<int>(inputRank) + axisInt) % static_cast<int>(inputRank));
1594
1595 OriginsDescriptor concatDescriptor(numConcatView, inputRank);
1596 concatDescriptor.SetConcatAxis(concatDimInput);
1597
1598 unsigned int mergeDimOrigin = 0;
1599
1600 std::vector<TensorShape> inputShapes;
1601 std::vector<std::string> tensorIds;
1602
1603 for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
1604 {
1605 std::string nodeName = node.input(static_cast<int>(viewIndex));
1606 auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info;
1607 inputShapes.push_back(inputTensorInfo.GetShape());
1608 tensorIds.push_back(nodeName);
1609
1610 // Set up concatDescriptor view origin
1611 armnnUtils::ProcessConcatInputTensorInfo(
1612 inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
1613 }
1614
1615 IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
1616 ARMNN_ASSERT(layer != nullptr);
1617
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001618 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
1619 m_TensorsInfo[node.input(0)].m_dtype);
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +01001620
1621 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1622
1623 // register the input connection slots for the layer, connections are made after all layers have been created
1624 RegisterInputSlots(layer, tensorIds);
1625
1626 // register the output connection slots for the layer, connections are made after all layers have been created
1627 RegisterOutputSlots(layer, { node.output(0) });
1628}
1629
Kevin Mayef33cb12021-01-29 14:24:57 +00001630void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001631{
1632 CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
1633 if (!node.attribute(0).has_t())
1634 {
James Ward58dec6b2020-09-11 17:32:44 +01001635 throw ParseException(fmt::format("Value not found for Constant node '{}' {}",
1636 node.name(),
1637 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001638 }
1639 const onnx::TensorProto& onnxTensor = node.attribute(0).t();
1640
Ryan OSheaed27ee72020-04-22 16:37:29 +01001641 //Register this tensor in m_TensorsInfo so we know we can use it as a constant parameter in future layers.
1642 m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
1643 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
1644 m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
1645
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001646 if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
1647 {
1648 CreateConstantLayer(node.output(0), node.name());
1649 }
1650 else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
1651 {
1652 CreateInt64ConstantLayer(node.output(0), node.name());
1653 }
1654 else
1655 {
 1656 throw ParseException(fmt::format("Data type not supported for Constant node '{}' {}",
1657 node.name(),
1658 CHECK_LOCATION().AsString()));
1659 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001660}
1661
Kevin Mayef33cb12021-01-29 14:24:57 +00001662void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001663{
1664 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); //input, weight, (bias)
1665 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1666
1667 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1668
1669 if(m_TensorsInfo[node.input(0)].m_info->GetNumDimensions() != 4)
1670 {
James Ward58dec6b2020-09-11 17:32:44 +01001671 throw ParseException(
1672 fmt::format("ArmNN only supports 2D convolution and Conv layer '{}' input {} {}",
1673 node.name(),
1674 TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
1675 m_TensorsInfo[node.input(0)].m_dtype),
1676 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001677 }
1678
1679 if(!m_TensorsInfo[node.input(1)].isConstant())
1680 {
James Ward58dec6b2020-09-11 17:32:44 +01001681 throw ParseException(
1682 fmt::format("Weights '{}' should be constant in Conv layer '{}' {}",
1683 node.input(1),
1684 node.name(),
1685 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001686 }
1687
1688 auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
1689
telsoa01c577f2c2018-08-31 09:22:23 +01001690 Convolution2dDescriptor desc;
1691 desc.m_BiasEnabled = false;
1692
1693 std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
1694 if(strides.empty())
1695 {
1696 desc.m_StrideX = 1;
1697 desc.m_StrideY = 1;
1698 }
1699 else
1700 {
1701 desc.m_StrideX = strides[1];
1702 desc.m_StrideY = strides[0];
1703 }
1704
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001705 std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
1706 if(!dilations.empty())
1707 {
1708 desc.m_DilationX = dilations[1];
1709 desc.m_DilationY = dilations[0];
1710 }
1711
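    // ONNX orders the 2D "pads" attribute as [top, left, bottom, right], matching the else-branch
    // below; the deprecated "auto_pad" alternative places any odd padding element at the
    // bottom/right for SAME_UPPER and at the top/left for SAME_LOWER.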
telsoa01c577f2c2018-08-31 09:22:23 +01001712 std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
1713 //Check new padding version first
1714 if(pads.empty())
1715 {
1716 //Check deprecated version
1717 std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
1718 if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
1719 {
1720 bool isUpper;
 1721 if(paddingString == "SAME_LOWER")
1722 {
1723 isUpper = false;
1724 }
1725 else if (paddingString == "SAME_UPPER")
1726 {
1727 isUpper = true;
1728 }
1729 else
1730 {
James Ward58dec6b2020-09-11 17:32:44 +01001731 throw ParseException(
1732 fmt::format("Invalid auto_pad attribute for node {}. Only SAME_UPPER, SAME_LOWER or VALID "
1733 "supported and found {} {}",
1734 node.name(),
1735 paddingString,
1736 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001737 }
1738 uint32_t inputHeight = inputInfo.GetShape()[2];
1739 uint32_t inputWidth = inputInfo.GetShape()[3];
1740
1741 uint32_t weightHeight;
1742 uint32_t weightWidth;
1743 std::vector<uint32_t> kernel_shape = ReadOptionalNodeUint32ListAttribute(node, "kernel_shape");
1744 if (kernel_shape.empty())
1745 {
1746 const TensorInfo weightTensorInfo = *m_TensorsInfo[node.input(1)].m_info;
1747 weightHeight = weightTensorInfo.GetShape()[2];
1748 weightWidth = weightTensorInfo.GetShape()[3];
1749 }
1750 else
1751 {
1752 weightHeight = kernel_shape[0];
1753 weightWidth = kernel_shape[1];
1754 }
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001755 CalcPadding(inputHeight,
1756 weightHeight,
1757 desc.m_StrideY,
1758 desc.m_DilationY,
1759 &desc.m_PadTop,
1760 &desc.m_PadBottom,
1761 isUpper);
1762 CalcPadding(inputWidth,
1763 weightWidth,
1764 desc.m_StrideX,
1765 desc.m_DilationX,
1766 &desc.m_PadLeft,
1767 &desc.m_PadRight,
1768 isUpper);
telsoa01c577f2c2018-08-31 09:22:23 +01001769 }
1770 }
1771 else
1772 {
1773 desc.m_PadTop = pads[0];
1774 desc.m_PadLeft = pads[1];
1775 desc.m_PadBottom = pads[2];
1776 desc.m_PadRight = pads[3];
1777 }
1778
1779 uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
1780 if(group > 1)
1781 {
1782 if (group > inputInfo.GetShape()[1])
1783 {
1784 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01001785 fmt::format("Error parsing Convolution node: {}. "
1786 "The 'group'={} parameter cannot be larger than the "
1787 "channel of the input shape={} (in NCHW format). {}",
1788 node.name(),
1789 group,
1790 inputInfo.GetShape()[1],
1791 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001792 }
1793 else if (group == inputInfo.GetShape()[1])
1794 {
1795 // we use a depthwise convolution here, because the number of groups equals to the
1796 // input channels
1797 AddConvLayerWithDepthwiseConv(node, desc);
1798 return;
1799 }
1800 else
1801 {
1802 // TODO: split the input by channels into channels/groups separate convolutions
Jim Flynne242f2d2019-05-22 14:24:13 +01001803 // and concatenate the results afterwards
James Ward58dec6b2020-09-11 17:32:44 +01001804 throw ParseException(fmt::format("Error parsing Convolution node: {}. "
1805 "The 'group'={} parameter should be 1 or be equal to the "
1806 "channel of the input shape={} (in NCHW format). {}",
1807 node.name(),
1808 group,
1809 inputInfo.GetShape()[1],
1810 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001811 }
1812 }
1813
Keith Davis721e6292022-05-17 10:06:53 +01001814 desc.m_BiasEnabled = (node.input_size() == 3);
 1815 armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, node.name().c_str());
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001816 std::vector<std::string> tensorIndexes = {node.input(0), node.input(1)};
1817
telsoa01c577f2c2018-08-31 09:22:23 +01001818 auto weightTensor = CreateConstTensor(node.input(1));
1819
Keith Davis721e6292022-05-17 10:06:53 +01001820 IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);
1821 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
1822 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1823
telsoa01c577f2c2018-08-31 09:22:23 +01001824 if (node.input_size() == 3)
1825 {
1826 if(!m_TensorsInfo[node.input(2)].isConstant())
1827 {
James Ward58dec6b2020-09-11 17:32:44 +01001828 throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
1829 node.input(2),
1830 node.name(),
1831 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001832 }
1834 auto biasTensor = CreateConstTensor(node.input(2));
Keith Davis721e6292022-05-17 10:06:53 +01001835
1836 IConnectableLayer* biasLayer = m_Network->AddConstantLayer(biasTensor.first);
1837 biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.first.GetInfo());
1838 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
1839
1840 tensorIndexes.emplace_back(node.input(2));
telsoa01c577f2c2018-08-31 09:22:23 +01001841 }
Keith Davis721e6292022-05-17 10:06:53 +01001842
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001843 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001844
1845 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
1846 { m_TensorsInfo[node.input(0)].m_info->GetShape(),
1847 m_TensorsInfo[node.input(1)].m_info->GetShape() });
1848 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1849
1850 // register the input connection slots for the layer, connections are made after all layers have been created
1851 // only the tensors for the inputs are relevant, exclude the const tensors
Keith Davisb4dd5cc2022-04-07 11:32:00 +01001852 RegisterInputSlots(layer, tensorIndexes);
telsoa01c577f2c2018-08-31 09:22:23 +01001853
1854 // register the output connection slots for the layer, connections are made after all layers have been created
1855 RegisterOutputSlots(layer, {node.output(0)});
1856}
1857
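// ONNX Flatten always yields a 2D tensor: (d_0 * ... * d_(axis-1), d_axis * ... * d_n).
// As a hypothetical example, shape {2, 3, 4} with axis = 1 flattens to {2, 12}.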
Kevin Mayef33cb12021-01-29 14:24:57 +00001858void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001859{
1860 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
1861 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1862
1863 CHECK_VALID_DATATYPE(node.name(), node.input(0),
1864 m_TensorsInfo[node.input(0)].m_dtype,
1865 onnx::TensorProto::FLOAT);
1866
1867 int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
1868 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1869
1870 /// Negative axis conversion
1871 if (axis < 0)
1872 {
1873 axis += inputShape.GetNumDimensions();
1874 }
1875
1876 /// Check Axis is within dimensions
1877 if (axis < 0 || axis >= inputShape.GetNumDimensions())
1878 {
James Ward58dec6b2020-09-11 17:32:44 +01001879 throw ParseException(fmt::format("Axis '{}' invalid. Tensor has '{}' dimensions in FlattenLayer '{}'",
1880 axis, inputShape.GetNumDimensions(), node.name()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001881 }
1882
 1883 /// If the chosen axis is 0, dimension1 will always be 1 in the output; default dimension2 to 1 because 0 is invalid
1884 uint dimension1{1};
1885 uint dimension2{1};
1886 uint i{0};
1887
1888 /// dimension1 = (d_0 * d_1 ... d_(axis-1))
1889 for (i = 0; i < axis; i++){
1890 dimension1 *= inputShape[i];
1891 }
1892
1893 /// dimension2 = (d_axis * d_(axis+1) ... d_n)
1894 for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++){
1895 dimension2 *= inputShape[i];
1896 }
1897
1898 TensorShape outputShape{dimension1, dimension2};
1899
1900 auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
1901 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
1902 CreateReshapeLayer(node.input(0), node.output(0), node.name());
1903}
1904
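// Gather output shape is the input shape with the gathered axis replaced by the indices shape;
// with hypothetical shapes, input {5, 4}, indices {2, 2} and axis 0 give an output of {2, 2, 4}.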
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001905void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
1906{
1907 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1908 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1909
1910 armnn::GatherDescriptor gatherDescriptor;
1911 gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));
1912
1913 IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
1914 ARMNN_ASSERT(layer != nullptr);
1915
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001916 const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1917 const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
1918 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
1919 m_TensorsInfo[node.input(0)].m_dtype);
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001920 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1921
1922 // register the input connection slots for the layer, connections are made after all layers have been created
1923 RegisterInputSlots(layer, { node.input(0), node.input(1) });
1924
1925 // register the output connection slots for the layer, connections are made after all layers have been created
1926 RegisterOutputSlots(layer, { node.output(0) });
1927}
1928
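// Gemm computes Y = alpha * A' * B' + beta * C, where A'/B' are optionally transposed.
// The parser maps this onto a FullyConnected layer: transB folds into m_TransposeWeightMatrix,
// transA becomes an explicit Transpose layer, and alpha/beta are applied with Linear
// activation layers (y = m_A * x) on the weight and bias paths respectively.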
Narumol Prangnawarat1112b012021-09-30 12:10:50 +01001929void OnnxParserImpl::ParseGemm(const onnx::NodeProto& node)
1930{
1931 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3);
1932 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1933
1934 int transA = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transA", 0));
1935 int transB = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transB", 0));
1936 float alpha = ReadOptionalNodeFloatAttribute(node, "alpha", 1.0);
1937 float beta = ReadOptionalNodeFloatAttribute(node, "beta", 1.0);
1938 bool biasEnabled = node.input_size() == 3;
1939
1940 TensorShape input0Shape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1941 TensorShape input1Shape = m_TensorsInfo[node.input(1)].m_info->GetShape();
1942
 1943 // if transB != 0, add transpose to the input1 (transpose weight matrix in FullyConnected)
1944 armnn::FullyConnectedDescriptor fullyConnectedDescriptor;
1945 fullyConnectedDescriptor.m_BiasEnabled = biasEnabled;
 1946 fullyConnectedDescriptor.m_TransposeWeightMatrix = (transB != 0);
1947
1948 IConnectableLayer* layer = nullptr;
1949
 1950 // Just add a FullyConnected layer; weights and biases are handled as inputs now.
1951 layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, node.name().c_str());
1952 ARMNN_ASSERT(layer != nullptr);
1953
1954 // if transA != 0, add transpose to the input0
1955 if (transA != 0)
1956 {
1957 std::string transAName = "transpose_" + node.input(0);
1958 armnn::TransposeDescriptor transposeADescriptor;
1959 transposeADescriptor.m_DimMappings = { 1, 0 };
1960 IConnectableLayer* transALayer = m_Network->AddTransposeLayer(transposeADescriptor, transAName.c_str());
1961 ARMNN_ASSERT(transALayer != nullptr);
1962 auto transAInfo = ComputeOutputInfo({ transAName }, transALayer, { input0Shape });
1963 transALayer->GetOutputSlot(0).SetTensorInfo(transAInfo[0]);
1964 transALayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
1965 // register the input connection slots for the layer, connections are made after all layers have been created
1966 RegisterInputSlot(transALayer, node.input(0), 0);
1967 input0Shape = transAInfo[0].GetShape();
1968 }
1969 else
1970 {
1971 RegisterInputSlot(layer, node.input(0), 0);
1972 }
1973
1974 // Add constant layer to store weights/biases and connect to FullyConnected layer.
1975 if(m_TensorsInfo[node.input(1)].isConstant())
1976 {
1977 IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(1)).first);
1978 TensorInfo weightInfo = *m_TensorsInfo[node.input(1)].m_info;
1979 weightInfo.SetConstant();
1980 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
1981
1982 // if alpha != 1, multiply to the weight
1983 if (alpha != 1)
1984 {
1985 std::string activationName = "activation_" + node.input(1);
1986 armnn::ActivationDescriptor activationDescriptor;
1987 activationDescriptor.m_A = alpha;
1988 activationDescriptor.m_Function = ActivationFunction::Linear;
1989 IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
1990 ARMNN_ASSERT(actLayer != nullptr);
1991
1992 auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { weightInfo.GetShape() });
1993 actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
1994 actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1995 weightsLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
1996 input1Shape = actInfo[0].GetShape();
1997 }
1998 else
1999 {
2000 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
2001 input1Shape = weightInfo.GetShape();
2002 }
2003 }
2004 else
2005 {
2006 // if alpha != 1, multiply to the weight
2007 if (alpha != 1)
2008 {
2009 std::string activationName = "activation_" + node.input(1);
2010 armnn::ActivationDescriptor activationDescriptor;
2011 activationDescriptor.m_A = alpha;
2012 activationDescriptor.m_Function = ActivationFunction::Linear;
2013 IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
2014 ARMNN_ASSERT(actLayer != nullptr);
2015
2016 auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { input1Shape });
2017 actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
2018 actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
2019 RegisterInputSlot(actLayer, node.input(1), 0);
2020 input1Shape = actInfo[0].GetShape();
2021 }
2022 else
2023 {
2024 RegisterInputSlot(layer, node.input(1), 1);
2025 }
2026 }
2027
2028 if(biasEnabled && m_TensorsInfo[node.input(2)].isConstant())
2029 {
2030 To1DTensor(node.input(2), CHECK_LOCATION());
2031 IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(2)).first);
2032 TensorInfo biasInfo = *m_TensorsInfo[node.input(2)].m_info;
2033 biasInfo.SetConstant();
2034 biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
2035
2036 // if beta != 1, multiply to the bias
2037 if (beta != 1)
2038 {
2039 std::string activationName = "activation_" + node.input(2);
2040 armnn::ActivationDescriptor activationDescriptor;
2041 activationDescriptor.m_A = beta;
2042 activationDescriptor.m_Function = ActivationFunction::Linear;
2043 IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
2044 ARMNN_ASSERT(actLayer != nullptr);
2045
2046 auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { biasInfo.GetShape() });
2047 actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
2048 actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
2049 biasLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
2050 }
2051 else
2052 {
2053 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
2054 }
2055 }
2056 else if (biasEnabled)
2057 {
2058 // Currently we support non-constant tensor of input C (bias) of Gemm when the dimension is 1
2059 if (m_TensorsInfo[node.input(2)].m_info->GetNumDimensions() != 1)
2060 {
 2061 throw ParseException(fmt::format("The parser supports a constant bias, or a non-constant bias "
 2062 "with 1 dimension, for Input C of Gemm. Input '{}' in node '{}' is not supported {}",
2063 node.input(2),
2064 node.name(),
2065 CHECK_LOCATION().AsString()));
2066 }
2067 // if beta != 1, multiply to the bias
2068 if (beta != 1)
2069 {
2070 std::string activationName = "activation_" + node.input(2);
2071 armnn::ActivationDescriptor activationDescriptor;
2072 activationDescriptor.m_A = beta;
2073 activationDescriptor.m_Function = ActivationFunction::Linear;
2074 IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
2075 ARMNN_ASSERT(actLayer != nullptr);
2076
2077 auto actInfo = ComputeOutputInfo({ activationName },
2078 actLayer,
2079 { m_TensorsInfo[node.input(2)].m_info->GetShape() });
2080 actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
2081 actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
2082 RegisterInputSlot(actLayer, node.input(2), 0);
2083 }
2084 else
2085 {
2086 RegisterInputSlot(layer, node.input(2), 2);
2087 }
2088 }
2089
2090 // Set final output of the FullyConnected layer
2091 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
2092 { input0Shape, input1Shape });
2093 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
2094
2095 RegisterOutputSlots(layer, {node.output(0)});
2096}
2097
Kevin Mayef33cb12021-01-29 14:24:57 +00002098void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01002099{
2100 Pooling2dDescriptor desc = Pooling2dDescriptor();
2101 desc.m_PoolType = PoolingAlgorithm::Average;
2102
2103 //kernel size is the same as input
2104 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
2105 desc.m_PoolWidth = inputShape[3];
2106 desc.m_PoolHeight = inputShape[2];
2107
2108 IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
2109 ARMNN_ASSERT(layer != nullptr);
2110
2111 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
2112 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
2113
2114 // register the input connection slots for the layer, connections are made after all layers have been created
2115 // only the tensors for the inputs are relevant, exclude the const tensors
2116 RegisterInputSlots(layer, {node.input(0)});
2117
2118 // register the output connection slots for the layer, connections are made after all layers have been created
2119 RegisterOutputSlots(layer, {node.output(0)});
2120}
2121
Kevin Mayef33cb12021-01-29 14:24:57 +00002122void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01002123{
2124 Pooling2dDescriptor desc;
2125 desc.m_PoolType = PoolingAlgorithm::Max;
2126 desc.m_PaddingMethod = PaddingMethod::Exclude;
2127 AddPoolingLayer(node, desc);
2128}
2129
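// Shape returns the dimensions of its input as a 1D INT64 tensor, hence the explicit
// onnx::TensorProto::INT64 data type when computing the output info below.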
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01002130void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
2131{
2132 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
2133 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
2134
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01002135 IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
2136 ARMNN_ASSERT(layer != nullptr);
2137
2138 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01002139 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01002140 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
2141
2142 // register the input connection slots for the layer, connections are made after all layers have been created
2143 RegisterInputSlots(layer, {node.input(0)});
2144
2145 // register the output connection slots for the layer, connections are made after all layers have been created
2146 RegisterOutputSlots(layer, {node.output(0)});
2147}
2148
Kevin Mayef33cb12021-01-29 14:24:57 +00002149void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01002150{
2151 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
2152 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
2153
2154 CHECK_VALID_DATATYPE(node.name(), node.input(0),
2155 m_TensorsInfo[node.input(0)].m_dtype,
2156 onnx::TensorProto::FLOAT); //input
2157 CHECK_VALID_DATATYPE(node.name(), node.input(1),
2158 m_TensorsInfo[node.input(1)].m_dtype,
2159 onnx::TensorProto::INT64); //shape
2160
Narumol Prangnawarat4b536e32021-10-18 12:35:19 +01002161 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
2162
2163 std::vector<unsigned int> targetShape;
2164 if(m_TensorsInfo[node.input(1)].isConstant())
Ryan OSheaed27ee72020-04-22 16:37:29 +01002165 {
Narumol Prangnawarat4b536e32021-10-18 12:35:19 +01002166 unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
2167 targetShape.reserve(dims);
2168
2169 for(uint i = 0; i < dims; i++)
2170 {
2171 int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
 2172 targetShape.push_back(static_cast<unsigned int>(val)); // reserve() allocates but does not construct elements, so append
2173 }
2174 }
2175 else
2176 {
2177 // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
2178 unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions();
2179 TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape();
2180 if (dims != 1 || shapes[0] > 2)
2181 {
2182 throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}",
2183 node.input(1),
2184 node.name(),
2185 CHECK_LOCATION().AsString()));
2186 }
2187
2188 unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements();
2189 if (shapes[0] == 1)
2190 {
2191 targetShape = { numInputElements };
2192 }
2193 else if (shapes[0] == 2)
2194 {
2195 targetShape = { inputShape[0] , numInputElements / inputShape[0] };
2196 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01002197 }
2198
2199 if(m_TensorsInfo[node.input(0)].isConstant())
2200 {
2201 //make a new cst tensor -> move the data to the output tensor (the shape is already good in the output tensor)
2202 if(m_TensorsInfo.count(node.output(0)) == 0)
2203 {
2204 m_TensorsInfo[node.output(0)] = OnnxTensor();
2205 }
2206 m_TensorsInfo[node.output(0)].m_tensor =
2207 std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
2208 }
2209 else
2210 {
Ryan OSheaed27ee72020-04-22 16:37:29 +01002211 if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
2212 {
Narumol Prangnawarat4b536e32021-10-18 12:35:19 +01002213 auto outInfo = ComputeReshapeInfo(
2214 TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
2215 inputShape, node.output(0));
Ryan OSheaed27ee72020-04-22 16:37:29 +01002216 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
2217 }
2218
2219 CreateReshapeLayer(node.input(0), node.output(0), node.name());
2220 }
2221}
2222
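// Unsqueeze inserts size-1 dimensions at the given axes and is implemented here as a reshape;
// as a hypothetical example, axes {0, 3} on a tensor of shape {3, 4} produce {1, 3, 4, 1}.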
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002223void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
2224{
2225 CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
2226 CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
2227
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002228 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
2229 std::vector<uint32_t> dims;
2230 if (node.input_size() == 1 && node.attribute_size() > 0)
2231 {
2232 dims = ReadMandatoryNodeUint32ListAttribute(node, "axes");
2233 }
2234 else
2235 {
2236 CHECK_VALID_DATATYPE(node.name(), node.input(1),
2237 m_TensorsInfo[node.input(1)].m_dtype,
2238 onnx::TensorProto::INT64); //axes
2239
2240 auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data();
2241 uint numDim = armnn::numeric_cast<uint>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
2242
2243 for(uint i = 0; i < numDim; i++)
2244 {
2245 uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i]));
2246 dims.push_back(uint32Value);
2247 }
2248 }
2249
2250 // Ensure that the axes are sorted
2251 std::sort(dims.begin(), dims.end());
2252
2253 std::vector<unsigned int> targetShape;
2254
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01002255 if (inputShape.GetDimensionality() != Dimensionality::Scalar)
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002256 {
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01002257 for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
2258 {
2259 targetShape.push_back(inputShape[i]);
2260 }
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002261 }
2262
2263 for(uint i = 0; i < dims.size(); i++)
2264 {
2265 targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
2266 }
2267
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01002268 auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
2269 inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002270 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01002271 m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01002272
2273 CreateReshapeLayer(node.input(0), node.output(0), node.name());
2274}
2275
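// Example with hypothetical shapes: preparing input0 of shape {5} for broadcast against
// input1 of shape {2, 3, 5} registers a tensor info of shape {1, 1, 5} under outputName and,
// when input0 is not constant, adds the reshape layer that produces it.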
Kevin Mayef33cb12021-01-29 14:24:57 +00002276void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
2277 const std::string& input0,
2278 const std::string& input1)
telsoa01c577f2c2018-08-31 09:22:23 +01002279{
 2280 //input0 should be reshaped to have the same number of dimensions as input1
2281 TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info);
2282
2283 TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
2284 TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();
2285
2286 uint32_t diff = input1Shape.GetNumDimensions() - input0Shape.GetNumDimensions();
2287 std::vector<uint32_t> newShape;
2288 while(diff > 0)
2289 {
2290 newShape.push_back(1);
2291 diff--;
2292 }
2293 for (uint dim = 0; dim < input0Shape.GetNumDimensions(); ++dim)
2294 {
2295 newShape.push_back(input0Shape[dim]);
2296 }
2297 outputTensorInfo.SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
2298
2299 //add the new tensor to m_TensorsInfo
2300 m_TensorsInfo[outputName] = OnnxTensor();
2301 m_TensorsInfo[outputName].m_info = std::make_unique<TensorInfo>(outputTensorInfo);
2302
2303 //add reshape layer if the parent was not constant...
 2304 if(!m_TensorsInfo[input0].isConstant())
2305 {
James Ward58dec6b2020-09-11 17:32:44 +01002306 CreateReshapeLayer(input0, outputName, fmt::format("Add:reshapeOf{}", input0));
telsoa01c577f2c2018-08-31 09:22:23 +01002307 }
 2308 else //make it constant and it will be created in Add
2309 {
2310 m_TensorsInfo[outputName].m_tensor = std::make_unique<onnx::TensorProto>(*m_TensorsInfo[input0].m_tensor);
2311
2312 }
2313}
2314
Kevin Mayef33cb12021-01-29 14:24:57 +00002315void OnnxParserImpl::SetupInputLayers()
telsoa01c577f2c2018-08-31 09:22:23 +01002316{
2317 //Find user input and add their layers
2318 for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
2319 {
2320 auto input = m_Graph->input(inputIndex);
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002321 if (!m_TensorsInfo[input.name()].isConstant())
telsoa01c577f2c2018-08-31 09:22:23 +01002322 {
2323 IConnectableLayer* layer =
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002324 m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
2325 TensorInfo tensorInfo = *m_TensorsInfo[input.name()].m_info;
2326 if (tensorInfo.GetShape().GetDimensionality() == Dimensionality::NotSpecified)
2327 {
2328 if (m_InputShapes.find(input.name()) == m_InputShapes.end())
2329 {
2330 throw ParseException(fmt::format("The parser does not support dynamic tensor, "
2331 "please specify input shape for {}. {}",
2332 input.name(),
2333 CHECK_LOCATION().AsString()));
2334 }
2335 else
2336 {
2337 tensorInfo.SetShape(m_InputShapes[input.name()]);
2338 m_TensorsInfo[input.name()].m_info = std::make_unique<TensorInfo>(tensorInfo);
2339 }
2340
2341 }
telsoa01c577f2c2018-08-31 09:22:23 +01002342 layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
2343
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002344 m_InputInfos[input.name()] = tensorInfo;
2345
telsoa01c577f2c2018-08-31 09:22:23 +01002346 RegisterOutputSlots(layer,{ input.name() });
2347 }
2348 }
2349}
2350
Kevin Mayef33cb12021-01-29 14:24:57 +00002351void OnnxParserImpl::SetupOutputLayers()
telsoa01c577f2c2018-08-31 09:22:23 +01002352{
2353 if(m_Graph->output_size() == 0)
2354 {
James Ward58dec6b2020-09-11 17:32:44 +01002355 throw ParseException(fmt::format("The given model does not have any outputs {}", CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002356 }
2357
2358 for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
2359 {
2360 IConnectableLayer* layer =
2361 m_Network->AddOutputLayer(static_cast<armnn::LayerBindingId>(outputIndex),
2362 m_Graph->output(outputIndex).name().c_str());
2363
2364 RegisterInputSlots(layer, { m_Graph->output(outputIndex).name() });
2365 }
2366}
2367
Narumol Prangnawarat1112b012021-09-30 12:10:50 +01002368void OnnxParserImpl::RegisterInputSlot(IConnectableLayer* layer,
2369 const std::string& tensorId,
2370 unsigned int slotIndex)
2371{
2372 armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
2373
2374 auto it = m_TensorConnections.find(tensorId);
2375
2376 if (it == m_TensorConnections.end())
2377 {
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002378 //First time seeing this tensor, we need to map it
Narumol Prangnawarat1112b012021-09-30 12:10:50 +01002379 m_TensorConnections[tensorId] = TensorSlots();
2380 }
2381 m_TensorConnections[tensorId].inputSlots.push_back(slot);
2382}
2383
Kevin Mayef33cb12021-01-29 14:24:57 +00002384void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
telsoa01c577f2c2018-08-31 09:22:23 +01002385{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002386 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01002387 if (tensorIds.size() != layer->GetNumInputSlots())
2388 {
2389 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01002390 fmt::format("The number of tensor inputs ({}) does not match the number expected ({}) {}",
2391 tensorIds.size(),
2392 layer->GetNumInputSlots(),
2393 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002394 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01002395
telsoa01c577f2c2018-08-31 09:22:23 +01002396 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
2397 {
2398 std::string tensorId = tensorIds[slotIndex];
2399 armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
2400
2401 auto it = m_TensorConnections.find(tensorId);
2402
2403 if (it == m_TensorConnections.end())
2404 {
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002405 // First time seeing this tensor, we need to map it
telsoa01c577f2c2018-08-31 09:22:23 +01002406 m_TensorConnections[tensorId] = TensorSlots();
2407 }
2408 m_TensorConnections[tensorId].inputSlots.push_back(slot);
2409 }
2410}
2411
Kevin Mayef33cb12021-01-29 14:24:57 +00002412void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
telsoa01c577f2c2018-08-31 09:22:23 +01002413{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002414 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01002415 if (tensorIds.size() != layer->GetNumOutputSlots())
2416 {
2417 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01002418 fmt::format("The number of tensor outputs ({}) does not match the number expected ({}) {} ",
2419 tensorIds.size(),
2420 layer->GetNumOutputSlots(),
2421 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002422 }
2423
2424 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
2425 {
2426 std::string tensorId = tensorIds[slotIndex];
2427 armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));
2428
2429 auto it = m_TensorConnections.find(tensorId);
2430
2431 if (it == m_TensorConnections.end())
2432 {
 2433 //First time seeing this tensor, we need to map it
2434 m_TensorConnections[tensorId] = TensorSlots();
2435 }
2436
Ryan OShea337c17f2020-02-21 12:33:17 +00002437 TensorSlots& tensorSlots = m_TensorConnections[tensorId];
telsoa01c577f2c2018-08-31 09:22:23 +01002438
2439 // assuming there is only one producer for that tensor
2440 if (tensorSlots.outputSlot != nullptr)
2441 {
James Ward58dec6b2020-09-11 17:32:44 +01002442 throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
2443 "tensor:{} {}",
2444 tensorId,
2445 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002446 }
2447 tensorSlots.outputSlot = slot;
2448 }
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002449
telsoa01c577f2c2018-08-31 09:22:23 +01002450}
2451
Kevin Mayef33cb12021-01-29 14:24:57 +00002452BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
telsoa01c577f2c2018-08-31 09:22:23 +01002453{
2454 for(int i = 0; i < m_Graph->input_size(); ++i)
2455 {
2456 auto input = m_Graph->input(i);
2457 if(input.name() == name)
2458 {
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002459 auto it = m_InputInfos.find(name);
2460
2461 if (it != m_InputInfos.end())
2462 {
2463 return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
2464 }
telsoa01c577f2c2018-08-31 09:22:23 +01002465 }
2466 }
James Ward58dec6b2020-09-11 17:32:44 +01002467 throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
2468 name, CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002469}
2470
Kevin Mayef33cb12021-01-29 14:24:57 +00002471BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const
telsoa01c577f2c2018-08-31 09:22:23 +01002472{
2473 for(int i = 0; i < m_Graph->output_size(); ++i)
2474 {
2475 auto output = m_Graph->output(i);
2476 if(output.name() == name)
2477 {
Narumol Prangnawarat1b11f322021-10-13 11:44:50 +01002478 auto it = m_OutputInfos.find(name);
2479
2480 if (it != m_OutputInfos.end())
2481 {
2482 return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
2483 }
telsoa01c577f2c2018-08-31 09:22:23 +01002484 }
2485 }
James Ward58dec6b2020-09-11 17:32:44 +01002486 throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
2487 name, CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002488}
2489
Kevin Mayef33cb12021-01-29 14:24:57 +00002490std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model)
telsoa01c577f2c2018-08-31 09:22:23 +01002491{
2492 if(model == nullptr) {
James Ward58dec6b2020-09-11 17:32:44 +01002493 throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2494 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002495 }
2496
2497 std::vector<std::string> inputNames;
2498 std::map<std::string, bool> isConstant;
2499 for(auto tensor : model->graph().initializer())
2500 {
2501 isConstant[tensor.name()] = true;
2502 }
2503 for(auto input : model->graph().input())
2504 {
2505 auto it = isConstant.find(input.name());
2506 if(it == isConstant.end())
2507 {
2508 inputNames.push_back(input.name());
2509 }
2510 }
2511 return inputNames;
2512}
2513
Kevin Mayef33cb12021-01-29 14:24:57 +00002514std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model)
telsoa01c577f2c2018-08-31 09:22:23 +01002515{
2516 if(model == nullptr) {
James Ward58dec6b2020-09-11 17:32:44 +01002517 throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2518 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002519 }
2520
2521 std::vector<std::string> outputNames;
2522 for(auto output : model->graph().output())
2523 {
2524 outputNames.push_back(output.name());
2525 }
2526 return outputNames;
2527}
2528
Matthew Sloyanac001ee2021-02-03 10:43:04 +00002529const std::string OnnxParserImpl::GetVersion()
2530{
2531 return ONNX_PARSER_VERSION;
2532}
2533
telsoa01c577f2c2018-08-31 09:22:23 +01002534} // namespace armnnOnnxParser