//
// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "OnnxParser.hpp"

#include "armnnOnnxParser/Version.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <ParserHelper.hpp>
#include <VerificationHelpers.hpp>

#include <fmt/format.h>

#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>

#include <iostream>
#include <numeric>
#include <armnnUtils/Permute.hpp>

using namespace armnn;
namespace armnnOnnxParser
{

IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {}

IOnnxParser::~IOnnxParser() = default;

IOnnxParser* IOnnxParser::CreateRaw()
{
    return new IOnnxParser();
}

IOnnxParserPtr IOnnxParser::Create()
{
    return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy);
}

void IOnnxParser::Destroy(IOnnxParser* parser)
{
    delete parser;
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
{
    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromBinary(binaryContent, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(
    const char* graphFile,
    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile,
                                                          const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile, inputShapes);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText, inputShapes);
}

BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
}

BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkOutputBindingInfo(name);
}
namespace
{
void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes,
                        const onnx::TensorProto::DataType actualValue,
                        const char* validExpr,
                        std::string nodeName,
                        std::string tensorName,
                        const armnn::CheckLocation& location)
{
    bool isValid = std::any_of(validInputTypes.begin(),
                               validInputTypes.end(),
                               [&actualValue](onnx::TensorProto::DataType x) { return x == actualValue; } );
    if (!isValid)
    {
        throw ParseException(
            fmt::format("Datatype {} is not valid for tensor '{}' of node '{}', not in {{{}}}. {}",
                        onnx::TensorProto::DataType_Name(actualValue),
                        tensorName,
                        nodeName,
                        validExpr,
                        location.AsString()));
    }
}

#define CHECK_VALID_DATATYPE(NODE, TENSOR, ACTUAL, ...) \
CheckValidDataType({__VA_ARGS__}, ACTUAL, #__VA_ARGS__, NODE, TENSOR, CHECK_LOCATION())

using StrTypeListPair = std::pair<const char*, std::initializer_list<onnx::TensorProto::DataType>>;
#define STR_LIST(...) StrTypeListPair(#__VA_ARGS__, {__VA_ARGS__})
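// Typical use, as seen throughout this file (illustration only, not new API):
//   CHECK_VALID_DATATYPE(node.name(), tensorName, actualType, onnx::TensorProto::FLOAT);
//   VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
// STR_LIST captures both the type list and its stringified form so that error
// messages can name the accepted datatypes verbatim.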

template <typename Callable>
void ReadMandatoryNodeAttributeImpl(const onnx::NodeProto& node,
                                    const std::string& attribName,
                                    onnx::AttributeProto::AttributeType expectedType,
                                    Callable callable)
{
    auto attribs = node.attribute();
    int attriNum = 0;
    while (attriNum < node.attribute_size())
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(fmt::format("Attribute {} of node {} expected to have {} as "
                                                 "onnx::AttributeProto::AttributeType, but found {} instead {}",
                                                 attribName,
                                                 node.name(),
                                                 onnx::AttributeProto::AttributeType_Name(expectedType),
                                                 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                                 CHECK_LOCATION().AsString()));
            }
            break;
        }
        ++attriNum;
    }
    if (attriNum == node.attribute_size())
    {
        throw ParseException(fmt::format("Could not find required attribute {} in node {} {}",
                                         attribName, node.name(), CHECK_LOCATION().AsString()));
    }
}

template <typename Callable>
void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
                                   const std::string& attribName,
                                   onnx::AttributeProto::AttributeType expectedType,
                                   Callable callable)
{
    auto attribs = node.attribute();
    for (int attriNum = 0; attriNum < node.attribute_size(); ++attriNum)
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(
                    fmt::format("Attribute {} of node {} expected to have {} as onnx::AttributeProto::AttributeType, "
                                "but found {} instead {}",
                                attribName,
                                node.name(),
                                onnx::AttributeProto::AttributeType_Name(expectedType),
                                onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
}

int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node,
                                  const std::string& name)
{
    int attribValue = 0;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                   [&attribValue](const onnx::AttributeProto& attrValue)
                                   {
                                       attribValue = CHECKED_INT32(attrValue.i());
                                   });
    return attribValue;
}

int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
                                       const std::string& name,
                                       const int64_t defaultValue = 0)
{
    int64_t attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.i();
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                           const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                   [&attriList](const onnx::AttributeProto& attrValue)
                                   {
                                       for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                       {
                                           attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                       }
                                   });
    return attriList;
}

uint32_t ReadOptionalNodeUint32Attribute(const onnx::NodeProto& node,
                                         const std::string& name,
                                         const uint32_t defaultVal = 0u)
{
    uint32_t attribValue = defaultVal;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = CHECKED_NON_NEGATIVE(CHECKED_INT32((attrValue.i())));
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                          const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                  [&attriList](const onnx::AttributeProto& attrValue)
                                  {
                                      for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                      {
                                          attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                      }
                                  });

    return attriList;
}

float ReadOptionalNodeFloatAttribute(const onnx::NodeProto& node,
                                     const std::string& name,
                                     const float defaultValue = 0.0f)
{
    float attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::FLOAT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.f();
                                  });
    return attribValue;
}

std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const std::string& name)
{
    std::string attribValue = "";
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::STRING,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.s();
                                  });
    return attribValue;
}

armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
    DataType type;
    switch(data_type)
    {
        case onnx::TensorProto::FLOAT:
        {
            type = DataType::Float32;
            break;
        }
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64:
        {
            type = DataType::Signed32;
            break;
        }
        default:
        {
            throw ParseException(
                fmt::format("'{}' is not a currently supported datatype for tensor {}."
                            " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                            onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
                            name,
                            CHECK_LOCATION().AsString() ));
        }
    }

    // Scalar Tensor
    if (shape.empty())
    {
        return TensorInfo(TensorShape(Dimensionality::Scalar), type);
    }

    // Dynamic Tensor
    if(std::find(shape.begin(), shape.end(), 0) != shape.end())
    {
        return TensorInfo(TensorShape(Dimensionality::NotSpecified), type);
    }

    return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}

armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
{
    const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
    std::vector<unsigned int> shapeDims;
    for (int i = 0; i < onnxShape.dim_size(); ++i)
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
    }

    return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}

armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
{
    std::vector<unsigned int> shapeDims;

    for (auto dim: tensor.dims())
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
    }

    return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}

std::string TensorInfoAsString(const TensorInfo& info,
                               const std::string& name,
                               const onnx::TensorProto::DataType& type)
{
    const TensorShape shape = info.GetShape();
    std::stringstream ss;
    ss << "tensor '" << name << "' contains "
       << onnx::TensorProto::DataType_Name(type)
       << " and has shape [";

    for (uint32_t i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        ss << shape[i] << ", ";
    }
    ss << shape[shape.GetNumDimensions() - 1] << "]";
    return ss.str();
}
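
// Computes SAME-style padding so that outputSize == ceil(inputSize / stride).
// Worked example: inputSize=4, filterSize=3, stride=2, dilation=1 gives
// outputSize=2, dilatedSize=3 and temp=5, so one padding element is needed; it
// goes to the back for SAME_UPPER (isUpper == true) and to the front for SAME_LOWER.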
void CalcPadding(uint32_t inputSize,
                 uint32_t filterSize,
                 uint32_t stride,
                 uint32_t dilation,
                 uint32_t* paddingFront,
                 uint32_t* paddingBack,
                 bool isUpper)
{
    uint32_t outputSize = (inputSize + stride - 1) / stride;
    uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
    uint32_t temp = (outputSize - 1) * stride + dilatedSize;
    *paddingFront = (temp - inputSize) / 2;
    *paddingBack = *paddingFront;
    if((temp - inputSize) % 2 == 1)
    {
        if (isUpper)
        {
            *paddingBack += 1;
        }
        else
        {
            *paddingFront += 1;
        }
    }
}
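
// Resolves an ONNX Reshape target shape against the input shape: a 0 entry keeps
// the corresponding input dimension and a single -1 entry is inferred from the
// remaining element count. Example: inShape [2,3,4] with target [0,-1] resolves
// to [2,12].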
TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                              const TensorShape& inShape,
                              const std::string& outName,
                              DataType dataType = DataType::Float32)
{
    std::vector<int> targetDims;
    for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
    {
        int val = CHECKED_INT32(targetShapeTensor[i]);
        if(val == 0)
        {
            targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
        }
        else
        {
            targetDims.push_back(val);
        }
    }

    std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
    const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
    if (stretchDim != targetDims.end())
    {
        if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
        {
            std::stringstream ss;
            ss << "[ ";
            for(uint i = 0; i < targetDims.size() - 1; ++i)
            {
                ss << targetDims[i] << ", ";
            }
            ss << targetDims[targetDims.size() - 1] << " ]";

            throw ParseException(
                fmt::format("Error during creation of reshaped tensor '{}'. At most one component of shape can be "
                            " -1 and here, shape is {} {}",
                            outName,
                            ss.str(),
                            CHECK_LOCATION().AsString()));
        }

        auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
            -1, std::multiplies<int32_t>()));
        auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
        outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
    }
    TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
    return TensorInfo(outShape, dataType);
}

} //namespace
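
// Dispatch table mapping ONNX operator names to parsing member functions;
// LoadGraph() looks each node up here. MatMul is deliberately absent: it is
// handled directly in LoadGraph() so that a MatMul followed by an Add can be
// fused into a single FullyConnected layer, while the "Add" entry below covers
// the unfused case.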
const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = {
    { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization},
    { "GlobalAveragePool",  &OnnxParserImpl::ParseGlobalAveragePool},
    { "AveragePool",        &OnnxParserImpl::ParseAveragePool },
    { "Clip",               &OnnxParserImpl::ParseClip },
    { "Constant",           &OnnxParserImpl::ParseConstant },
    { "MaxPool",            &OnnxParserImpl::ParseMaxPool },
    { "Reshape",            &OnnxParserImpl::ParseReshape },
    { "Sigmoid",            &OnnxParserImpl::ParseSigmoid },
    { "Tanh",               &OnnxParserImpl::ParseTanh },
    { "Relu",               &OnnxParserImpl::ParseRelu },
    { "LeakyRelu",          &OnnxParserImpl::ParseLeakyRelu },
    { "Conv",               &OnnxParserImpl::ParseConv },
    { "Add",                &OnnxParserImpl::ParseAdd },
    { "Flatten",            &OnnxParserImpl::ParseFlatten },
    { "Shape",              &OnnxParserImpl::ParseShape },
    { "Gather",             &OnnxParserImpl::ParseGather },
    { "Unsqueeze",          &OnnxParserImpl::ParseUnsqueeze },
    { "Concat",             &OnnxParserImpl::ParseConcat },
    { "Gemm",               &OnnxParserImpl::ParseGemm }
};

template<typename TypePair, typename Location>
void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
                                    TypePair validInputs,
                                    const Location& location)
{
    for(auto input : node.input())
    {
        CheckValidDataType(validInputs.second,
                           m_TensorsInfo[input].m_dtype,
                           validInputs.first,
                           node.name(),
                           input,
                           location);
    }
}

#define VALID_INPUTS(NODE, VALID_INPUTS) \
    OnnxParserImpl::ValidateInputs(NODE, \
                                   VALID_INPUTS, \
                                   CHECK_LOCATION())
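
// Returns the TensorInfo for each named output. If any output's info is missing
// or its shape is NotSpecified, all output shapes are inferred from the layer and
// the supplied input shapes; dataType selects Float32 (the default) or Signed32
// element types for the inferred infos.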
std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
                                                          const IConnectableLayer* layer,
                                                          std::vector<TensorShape> inputShapes,
                                                          const onnx::TensorProto::DataType& dataType)
{
    ARMNN_ASSERT(! outNames.empty());
    bool needCompute = std::any_of(outNames.begin(),
                                   outNames.end(),
                                   [this](std::string name)
                                   {
                                       return (m_TensorsInfo.count(name) == 0 ||
                                               m_TensorsInfo[name].m_info == nullptr ||
                                               m_TensorsInfo[name].m_info->GetShape().GetDimensionality() ==
                                               Dimensionality::NotSpecified);
                                   });
    std::vector<TensorInfo> outInfo;
    // if the output info(s) are not here, we need to compute them
    std::vector<TensorShape> inferredShapes;
    DataType armnnType = DataType::Float32;
    if(needCompute) {
        inferredShapes = layer->InferOutputShapes(inputShapes);
        ARMNN_ASSERT(inferredShapes.size() == outNames.size());
        switch (dataType) {
            case onnx::TensorProto::FLOAT: {
                armnnType = DataType::Float32;
                break;
            }
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::INT64: {
                armnnType = DataType::Signed32;
                break;
            }
            default: {
                throw ParseException(
                    fmt::format("'{}' is not a currently supported datatype for {}."
                                " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                                onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
                                layer->GetName(),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
    for (uint i = 0; i < outNames.size(); ++i)
    {
        if(needCompute)
        {
            m_TensorsInfo[outNames[i]] = OnnxTensor();
            m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
                TensorInfo(inferredShapes[i], armnnType));
            m_TensorsInfo[outNames[i]].m_dtype = dataType;
        }
        outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
    }
    return outInfo;
}

OnnxParserImpl::OnnxParserImpl()
    : m_Network(nullptr, nullptr)
{
}

void OnnxParserImpl::ResetParser()
{
    m_Network = armnn::INetworkPtr(nullptr, nullptr);
    m_Graph = nullptr;
    m_InputInfos.clear();
    m_OutputInfos.clear();
}

void OnnxParserImpl::Cleanup()
{
    m_TensorConnections.clear();
    m_TensorsInfo.clear();
    m_OutputsMap.clear();
    m_OutputsFusedAndUsed.clear();
    m_InputShapes.clear();
}
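
// Copies (and optionally permutes) constant tensor data into freshly allocated
// storage. The returned unique_ptr owns the buffer while the ConstTensor only
// holds a raw pointer into it, so the two must be kept alive together.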
template<typename T>
std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
CreateConstTensorImpl(const T* bufferPtr,
                      armnn::TensorInfo& tensorInfo,
                      const armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());

    std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);

    if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
    {
        tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
        armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
                            reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
    }
    else
    {
        ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
    }

    return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
}

std::pair<ConstTensor, std::unique_ptr<float[]>>
OnnxParserImpl::CreateConstTensor(const std::string name,
                                  armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    // ONNX can have Float16 and double constant nodes but ArmNN only supports float32
    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();

    // Const tensors require at least a list of values
    if (tensorInfo.GetNumElements() == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    auto srcData = onnxTensor.float_data().data();
    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.float_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }
        return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
    }
    else
    {
        return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
                                            tensorInfo,
                                            permutationVector);
    }
}

std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
                                       armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();
    uint numElements = tensorInfo.GetNumElements();

    // Const tensors require at least a list of values
    if (numElements == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        auto srcData = onnxTensor.int64_data().data();
        if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.int64_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }

        // INT64 values are narrowed to INT32 since ArmNN has no 64-bit integer tensor type
        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }

        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
    else
    {
        auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }
        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
}
ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "r");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    using google::protobuf::io::FileInputStream;
    std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
    bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile,
                                                      const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent,
                                                    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromBinary(binaryContent);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromBinary(const std::vector<uint8_t>& binaryContent)
{
    if (binaryContent.size() == 0)
    {
        throw ParseException(fmt::format("Missing binary content {}", CHECK_LOCATION().AsString()));
    }
    // Parse the buffer into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::CodedInputStream codedStream(binaryContent.data(), static_cast<int>(binaryContent.size()));
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "rb");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::FileInputStream inStream(fileno(fd));
    google::protobuf::io::CodedInputStream codedStream(&inStream);
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile,
                                                        const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
{
    if (protoText == "")
    {
        throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
                                                   CHECK_LOCATION().AsString()));
    }
    // Parse the string into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText,
                                                    const std::map<std::string, armnn::TensorShape>& inputShapes)
{
    ResetParser();
    m_InputShapes = inputShapes;
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
{
    m_Network = INetwork::Create();
    try
    {
        m_Graph = std::make_unique<onnx::GraphProto>(*model.mutable_graph());
        LoadGraph();
    }
    catch (const ParseException& e)
    {
        Cleanup();
        throw e;
    }
    Cleanup();
    return std::move(m_Network);
}
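
// Drives the parse: fills m_TensorsInfo from the graph's inputs, outputs,
// value_info and initializers, creates input/output layers, detects MatMul+Add
// fusions, dispatches each node to its parser, then wires up all recorded slot
// connections in a single pass at the end.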
void OnnxParserImpl::LoadGraph()
{
    ARMNN_ASSERT(m_Graph.get() != nullptr);

    // Fill m_TensorsInfo with the shapes and value of every tensor
    SetupInfo(m_Graph->mutable_output());
    SetupInfo(m_Graph->mutable_input());
    SetupInfo(m_Graph->mutable_value_info());

    for (auto tensor : m_Graph->initializer())
    {
        m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.data_type());
    }

    SetupInputLayers();
    SetupOutputLayers();

    // Detect FullyConnected layers with bias and update the FusedAndUsed map accordingly
    DetectFullyConnected();

    // Parse the graph
    for(size_t nodeIndex = 0; nodeIndex < static_cast<size_t>(m_Graph->node_size()); nodeIndex++)
    {
        auto node = m_Graph->node(static_cast<int>(nodeIndex));
        const std::string& operation = node.op_type();

        // Check which layers we handled already (Add and MatMul fused as FullyConnected)
        if (operation == "MatMul" )
        {
            if(m_OutputsFusedAndUsed[nodeIndex].inputForNodes != m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.size())
            {
                // Node which cannot be fused as a FullyConnected layer (used in layers as a simple MatMul output)
                AddFullyConnected(node);
            }
        }
        else if (!(m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) && operation == "Add")
        {
            int matmulIndex = static_cast<int> (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes[0]);
            AddFullyConnected(m_Graph->node(matmulIndex), &node);
        }
        else if (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) // node is not part of a fused layer
        {
            auto it = m_ParserFunctions.find(operation);
            if (it != m_ParserFunctions.end())
            {
                auto func = it->second;
                (this->*func)(node);
            }
            else
            {
                throw ParseException(fmt::format("Unsupported operation {} for node '{}' {}",
                                                 operation,
                                                 node.name(),
                                                 CHECK_LOCATION().AsString()));
            }
        }
    }

    // Make the connections between the outputs and inputs of each layer
    for (const auto& tensorCon : m_TensorConnections)
    {
        if (tensorCon.second.outputSlot != nullptr)
        {
            for (size_t inputSlotIdx = 0; inputSlotIdx < tensorCon.second.inputSlots.size(); ++inputSlotIdx)
            {
                tensorCon.second.outputSlot->Connect(*(tensorCon.second.inputSlots[inputSlotIdx]));
            }
        }
    }

    // Get output info.
    for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
    {
        auto output = m_Graph->output(outputIndex);
        m_OutputInfos[output.name()] = *m_TensorsInfo[output.name()].m_info;
    }
}

void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
{
    for (auto tensor : *list)
    {
        m_TensorsInfo[tensor.name()] = OnnxTensor();
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.type().tensor_type().elem_type());
    }
}
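
// ONNX has no FullyConnected operator, so an Add where one input is constant and
// the other is produced by a MatMul is marked for fusion here. m_OutputsFusedAndUsed
// records, per node, which nodes it was fused with and how many consumers its
// output has, so a MatMul that also feeds other layers is still emitted on its own.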
void OnnxParserImpl::DetectFullyConnected()
{
    m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary());
    auto matmulAndConstant = [&](const std::string& constInput,
                                 const std::string& matmulInput,
                                 int& nodeIndex)
    {
        auto matmulIt = m_OutputsMap.find(matmulInput);
        if(matmulIt != m_OutputsMap.end() && matmulIt->second.first->op_type() == "MatMul"
            && m_TensorsInfo[constInput].isConstant())
        {
            nodeIndex = matmulIt->second.second;
            return true;
        }
        return false;
    };

    for(int nodeIndex = 0; nodeIndex < m_Graph->node_size(); nodeIndex++)
    {
        const onnx::NodeProto* node = &m_Graph->node(nodeIndex);
        for (const std::string& output : node->output())
        {
            m_OutputsMap[output] = std::make_pair(node, nodeIndex);
        }

        for (const std::string& input : node->input()) // count how many times a node is used as input
        {
            auto matmulIt = m_OutputsMap.find(input);
            if(matmulIt != m_OutputsMap.end()){
                ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes; // node used
            }
        }

        if (node->op_type() == "Add")
        {
            int matmulIndex = 0;
            if (matmulAndConstant(node->input(0), node->input(1), matmulIndex) ||
                matmulAndConstant(node->input(1), node->input(0), matmulIndex))
            {
                // MatMul and Add were fused
                m_OutputsFusedAndUsed[static_cast<size_t>(matmulIndex)].fusedWithNodes
                    .push_back(static_cast<size_t>(nodeIndex));

                m_OutputsFusedAndUsed[static_cast<size_t>(nodeIndex)].fusedWithNodes
                    .push_back(static_cast<size_t>(matmulIndex));
            }
        }
    }

    for (auto output: m_Graph->output()) { // Add usages as output of the graph in count of usages
        auto matmulIt = m_OutputsMap.find(output.name());
        if(matmulIt != m_OutputsMap.end()){
            ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes;
        }
    }
}

template<typename Location>
void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node,
                                      std::string* inputName,
                                      std::string* constName,
                                      const Location& location)
{
    int cstIndex;
    if (m_TensorsInfo[node.input(0)].isConstant())
    {
        cstIndex = 0;
    }
    else if (m_TensorsInfo[node.input(1)].isConstant())
    {
        cstIndex = 1;
    }
    else
    {
        throw ParseException(fmt::format("One of the input tensors ('{}' or '{}') should be constant in node '{}' {}",
                                         node.input(0),
                                         node.input(1),
                                         node.name(),
                                         location.AsString()));
    }
    if(constName)
    {
        *constName = node.input(cstIndex);
    }
    if(inputName)
    {
        *inputName = node.input(!cstIndex);
    }
}
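
// Squeezes a bias tensor of shape [1, ..., 1, X] down to the 1D shape [X] expected
// by ArmNN, e.g. [1, 1, 256] becomes [256]; throws if any leading dimension is not 1.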
template<typename Location>
void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location)
{
    TensorShape shape = m_TensorsInfo[name].m_info->GetShape();
    std::vector<uint32_t> newShape;
    for(uint i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        if(shape[i] != 1)
        {
            throw ParseException(
                fmt::format("Only tensors with shape [1, ..., 1, X] can be converted to 1D and {} {}",
                            TensorInfoAsString(*m_TensorsInfo[name].m_info, name, m_TensorsInfo[name].m_dtype),
                            location.AsString()));
        }
    }
    newShape.push_back(shape[shape.GetNumDimensions() - 1]);

    m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
}

void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
{
    ARMNN_ASSERT(node.op_type() == "Conv");

    DepthwiseConvolution2dDescriptor desc;
    desc.m_PadLeft     = convDesc.m_PadLeft;
    desc.m_PadRight    = convDesc.m_PadRight;
    desc.m_PadTop      = convDesc.m_PadTop;
    desc.m_PadBottom   = convDesc.m_PadBottom;
    desc.m_StrideX     = convDesc.m_StrideX;
    desc.m_StrideY     = convDesc.m_StrideY;
    desc.m_BiasEnabled = convDesc.m_BiasEnabled;

    armnn::IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, node.name().c_str());
    std::string permuteStr = "permute_" + node.input(1);
    std::vector<std::string> tensorIndexes = {node.input(0), permuteStr};

    auto weightTensor = CreateConstTensor(node.input(1));
    IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);

    // Weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNN's depthwise weights layout [1,H,W,O]
    armnn::PermutationVector perVec {3, 0, 1, 2};
    TensorInfo weightsPermuted = armnnUtils::Permuted(weightTensor.first.GetInfo(), perVec);

    // Inserts NewLayer so layers don't need to be re-sorted.
    IConnectableLayer* permuteLayer = m_Network->AddPermuteLayer(PermuteDescriptor(perVec),
                                                                 "permute_layer");
    permuteLayer->GetOutputSlot(0).SetTensorInfo(weightsPermuted);
    permuteLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));

    weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
    weightsLayer->GetOutputSlot(0).Connect(permuteLayer->GetInputSlot(0u));

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }

        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        tensorIndexes.emplace_back(node.input(2));

        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(biasTensor.first);
        biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.first.GetInfo());
        biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
    }

    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          weightsPermuted.GetShape() });

    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, tensorIndexes);

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}
void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
{
    // find matmul inputs
    std::string inputName;
    std::string weightName;
    std::string biasName;
    std::string outputName;
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
    VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));

    GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());

    TensorInfo inputInfo  = *m_TensorsInfo[inputName].m_info;
    TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
    TensorInfo biasInfo;

    std::vector<std::string> inputNames;

    FullyConnectedDescriptor desc;
    desc.m_BiasEnabled = addNode != nullptr;

    IConnectableLayer* layer = nullptr;
    if(desc.m_BiasEnabled)
    {
        // find bias const
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
        VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));

        GetInputAndParam(*addNode, nullptr, &biasName, CHECK_LOCATION());

        // Output shape is [1, weights[1]] and a 1D vector in ONNX can be [1,X], so we convert biases to ArmNN 1D
        To1DTensor(biasName, CHECK_LOCATION());
        biasInfo = *m_TensorsInfo[biasName].m_info;

        if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
        {
            throw ParseException(
                fmt::format("Shape of weights '{}' and bias of following Add node '{}' do not match : {}"
                            " and {} ( /!\\ bias should be a 1D tensor) {}",
                            weightName,
                            addNode->name(),
                            TensorInfoAsString(*m_TensorsInfo[weightName].m_info, weightName,
                                               m_TensorsInfo[weightName].m_dtype),
                            TensorInfoAsString(*m_TensorsInfo[biasName].m_info, biasName,
                                               m_TensorsInfo[biasName].m_dtype ),
                            CHECK_LOCATION().AsString()));
        }

        inputNames = { inputName, weightName, biasName };
        outputName = addNode->output(0);
    }
    else
    {
        inputNames = { inputName, weightName };
        outputName = matmulNode.output(0);
    }

    // Just add a FullyConnected layer; weights and biases are handled as inputs now.
    layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    if (inputInfo.GetNumDimensions() > 2)
    {
        // Add reshape to flatten to 2D [batch_size, input_size],
        // where "input_size" corresponds to the number of inputs to the layer,
        // matching the second dimension of weights,
        // and "batch_size" is calculated by dividing the number of elements by "input_size".
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightInfo.GetShape()[0];
        reshapedDimensions[0] = inputInfo.GetNumElements() / reshapedDimensions[1];

        if (inputInfo.GetNumElements() % reshapedDimensions[1] != 0)
        {
            throw ParseException(
                fmt::format("Failed to deduce input tensor shape from filter size {} {}",
                            reshapedDimensions[1],
                            CHECK_LOCATION().AsString()));
        }

        TensorInfo reshapedTensorInfo = inputInfo;
        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
        inputInfo = reshapedTensorInfo;

        ReshapeDescriptor reshapeDescriptor;
        reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();

        std::string reshapeLayerName = fmt::format("Reshape_for:{}", layer->GetName());
        IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(reshapeDescriptor, reshapeLayerName.c_str());

        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));

        RegisterInputSlots(reshapeLayer, {inputName});
        inputNames[0] = reshapeLayerName;
    }

    auto outputInfo = ComputeOutputInfo({ outputName },
                                        layer,
                                        { inputInfo.GetShape(),
                                          weightInfo.GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    RegisterInputSlots(layer, inputNames);

    // Add constant layers to store the weights/biases and connect them to the FullyConnected layer.
    if(m_TensorsInfo[weightName].isConstant())
    {
        IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);

        weightInfo.SetConstant();
        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
        weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
    }

    if(desc.m_BiasEnabled && m_TensorsInfo[biasName].isConstant())
    {
        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);

        biasInfo.SetConstant();
        biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
        biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
    }

    if (outputInfo[0].GetNumDimensions() > 2)
    {
        // Calculate reshape to flatten to 2D [batch_size, input_size]
        std::vector<unsigned int> reshapedDimensions(2);
        reshapedDimensions[1] = weightInfo.GetShape()[1];
        reshapedDimensions[0] = outputInfo[0].GetNumElements() / reshapedDimensions[1];

        if (outputInfo[0].GetNumElements() % reshapedDimensions[1] != 0)
        {
            throw ParseException(
                fmt::format("Failed to deduce output tensor shape from filter size {} {}",
                            reshapedDimensions[1],
                            CHECK_LOCATION().AsString()));
        }

        armnn::TensorInfo reshapedOutputTensorInfo = outputInfo[0];
        reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
        layer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);

        // Renamed from "desc" to avoid shadowing the FullyConnectedDescriptor above
        ReshapeDescriptor expandDesc;
        expandDesc.m_TargetShape = outputInfo[0].GetShape();

        std::string reshapeLayerName = fmt::format("ExpandDims_for:{}", layer->GetName());
        IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(expandDesc, reshapeLayerName.c_str());

        layer->GetOutputSlot(0).Connect(reshapeLayer->GetInputSlot(0));
        reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        RegisterInputSlots(reshapeLayer, {layer->GetName()});
        layer = reshapeLayer;
    }

    RegisterOutputSlots(layer, { outputName });
}
void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    std::vector<uint32_t> kernel_shape = ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape"); // size of the pooling window
    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");

    desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
    desc.m_PoolWidth  = kernel_shape[1];
    desc.m_PoolHeight = kernel_shape[0];

    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    // Check the newer 'pads' attribute first
    if(pads.empty())
    {
        // Fall back to the deprecated 'auto_pad' attribute
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if(paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if(paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(fmt::format("Invalid auto_pad attribute for node {}. "
                                                 "Only SAME_UPPER, SAME_LOWER or VALID supported and found {} {}",
                                                 node.name(),
                                                 paddingString,
                                                 CHECK_LOCATION().AsString()));
            }
            auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth  = inputInfo.GetShape()[3];
            CalcPadding(inputHeight,
                        desc.m_PoolHeight,
                        desc.m_StrideY,
                        1u,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        desc.m_PoolWidth,
                        desc.m_StrideX,
                        1u,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop    = pads[0];
        desc.m_PadLeft   = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight  = pads[3];
    }

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

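// Aligns the ranks of two element-wise operands by reshaping the lower-rank tensor.
// For example (derived from PrependForBroadcast below), adding a tensor of shape {3}
// to one of shape {2,3} reshapes the first operand to {1,3} before the layer is added.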
std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0,
                                                                        const std::string& input1)
{
    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input1);
        PrependForBroadcast(outputName, input1, input0);
        inputs.second = outputName;
    }
    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input0);
        PrependForBroadcast(outputName, input0, input1);
        inputs.first = outputName;
    }
    return inputs;
}

void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateInt64ConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName,
                                        const std::string& outputName,
                                        const std::string& layerName)
{
    const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info;
    ReshapeDescriptor reshapeDesc;
    reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();

    IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
    ARMNN_ASSERT(layer != nullptr);
    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {inputName});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {outputName});
}

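// Generic activation handler. For Clip (mapped to BoundedReLu) the bounds come either
// from the deprecated min/max attributes (pre-opset-11 models) or, in newer opsets,
// from the optional inputs 1 (min) and 2 (max); m_A holds the upper bound and m_B the
// lower bound.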
void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    ActivationDescriptor desc;
    desc.m_Function = func;

    if (func == ActivationFunction::BoundedReLu)
    {
        if (node.input_size() == 1 && node.attribute_size() > 0)
        {
            desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
            desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
        }
        else
        {
            desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
            desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
        }
    }

    IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseClip(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::BoundedReLu);
}

void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::Sigmoid);
}

void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::TanH);
}

void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::ReLu);
}

void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::LeakyReLu);
}

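// Element-wise addition with limited broadcasting: after AddPrepareBroadcast has
// aligned the ranks, each pair of dimensions must either match or contain a 1.
// For example {2,3,4,5} + {1,3,1,5} is accepted, while {2,3,4,5} + {2,2,4,5} is
// rejected.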
void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    // TODO: unify broadcast validation code across layers
    // tracked by: IVGCVSW-1576

    // Check broadcast compatibility: after rank alignment, dimensions must match or be 1
    auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
    auto input0 = *m_TensorsInfo[inputs.first].m_info;
    auto input1 = *m_TensorsInfo[inputs.second].m_info;
    ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());

    unsigned int numDims = input0.GetNumDimensions();
    for (unsigned int i = 0; i < numDims; i++)
    {
        unsigned int dim0 = input0.GetShape()[i];
        unsigned int dim1 = input1.GetShape()[i];
        if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
        {
            throw ParseException(
                fmt::format("Broadcast is only supported for scalar or 1D tensors in Add node '{}'. "
                            "Input dimensions should either match or one should be of size 1 and here, "
                            "{} and {} {}",
                            node.name(),
                            TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
                                               m_TensorsInfo[inputs.first].m_dtype),
                            TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
                                               m_TensorsInfo[inputs.second].m_dtype),
                            CHECK_LOCATION().AsString()));
        }
    }

    IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Add, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[inputs.first].m_info->GetShape(),
                                          m_TensorsInfo[inputs.second].m_info->GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connections; constant inputs need a dedicated constant layer
    if(m_TensorsInfo[inputs.first].isConstant())
    {
        CreateConstantLayer(inputs.first, fmt::format("Add:constant_of_{}", node.input(0)));
    }
    if(m_TensorsInfo[inputs.second].isConstant())
    {
        CreateConstantLayer(inputs.second, fmt::format("Add:constant_of_{}", node.input(1)));
    }
    RegisterInputSlots(layer, {inputs.first, inputs.second});

    // register the output connection
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc;
    desc.m_PoolType = PoolingAlgorithm::Average;

    uint32_t count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
    if(count_include_pad)
    {
        desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
    }
    AddPoolingLayer(node, desc);
}

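// Inference-mode batch normalisation: y = scale * (x - mean) / sqrt(var + eps) + bias.
// All four parameter tensors (scale, bias, mean, var) must be constant initializers.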
void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
{
    // The momentum and spatial attributes are ignored.

    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
    for(int ind = 1; ind < node.input_size(); ++ind)
    {
        auto tensor = node.input(ind);
        if(!m_TensorsInfo[tensor].isConstant())
        {
            throw ParseException(
                fmt::format("Input tensor '{}' should be constant in BatchNormalization node '{}' {}",
                            tensor,
                            node.name(),
                            CHECK_LOCATION().AsString()));
        }
    }

    float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
    BatchNormalizationDescriptor desc;
    desc.m_Eps = epsilon;

    auto scaleTensor = CreateConstTensor(node.input(1));
    auto biasTensor  = CreateConstTensor(node.input(2));
    auto meanTensor  = CreateConstTensor(node.input(3));
    auto varTensor   = CreateConstTensor(node.input(4));

    IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
                                                                     meanTensor.first,
                                                                     varTensor.first,
                                                                     biasTensor.first,
                                                                     scaleTensor.first,
                                                                     node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    RegisterInputSlots(layer, {node.input(0)}); // don't register constant inputs

    // register the output connection
    RegisterOutputSlots(layer, {node.output(0)});
}

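// Concatenation along a (possibly negative) axis. The axis is normalised as
// (rank + axis) % rank, so for a rank-4 input axis=-1 becomes 3 and axis=1 stays 1.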
void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    uint32_t numConcatView = static_cast<uint32_t>(node.input_size());
    uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions();

    int axisInt = ReadMandatoryNodeIntAttribute(node, "axis");

    unsigned int concatDimInput = static_cast<unsigned int>(
        (static_cast<int>(inputRank) + axisInt) % static_cast<int>(inputRank));

    OriginsDescriptor concatDescriptor(numConcatView, inputRank);
    concatDescriptor.SetConcatAxis(concatDimInput);

    unsigned int mergeDimOrigin = 0;

    std::vector<TensorShape> inputShapes;
    std::vector<std::string> tensorIds;

    for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
    {
        std::string nodeName = node.input(static_cast<int>(viewIndex));
        auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info;
        inputShapes.push_back(inputTensorInfo.GetShape());
        tensorIds.push_back(nodeName);

        // Set up concatDescriptor view origin
        armnnUtils::ProcessConcatInputTensorInfo(
            inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
    }

    IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
                                        m_TensorsInfo[node.input(0)].m_dtype);

    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    RegisterInputSlots(layer, tensorIds);

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, { node.output(0) });
}

void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
    if (!node.attribute(0).has_t())
    {
        throw ParseException(fmt::format("Value not found for Constant node '{}' {}",
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }
    const onnx::TensorProto& onnxTensor = node.attribute(0).t();

    // Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
    m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
    m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());

    if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
    {
        CreateConstantLayer(node.output(0), node.name());
    }
    else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
    {
        CreateInt64ConstantLayer(node.output(0), node.name());
    }
    else
    {
        throw ParseException(fmt::format("Data type not supported for Constant node '{}' {}",
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }
}

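// 2D convolution only (4D NCHW input). The ONNX "group" attribute is handled in three
// cases below: group == 1 is a plain convolution, group == input channels becomes a
// DepthwiseConvolution2d layer, and 1 < group < input channels is not supported.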
void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); // input, weight, (bias)
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    if(m_TensorsInfo[node.input(0)].m_info->GetNumDimensions() != 4)
    {
        throw ParseException(
            fmt::format("ArmNN only supports 2D convolution and Conv layer '{}' input {} {}",
                        node.name(),
                        TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
                                           m_TensorsInfo[node.input(0)].m_dtype),
                        CHECK_LOCATION().AsString()));
    }

    if(!m_TensorsInfo[node.input(1)].isConstant())
    {
        throw ParseException(
            fmt::format("Weights '{}' should be constant in Conv layer '{}' {}",
                        node.input(1),
                        node.name(),
                        CHECK_LOCATION().AsString()));
    }

    auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;

    Convolution2dDescriptor desc;
    desc.m_BiasEnabled = false;

    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
    if(!dilations.empty())
    {
        desc.m_DilationX = dilations[1];
        desc.m_DilationY = dilations[0];
    }

    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
    // Check the newer 'pads' attribute first
    if(pads.empty())
    {
        // Fall back to the deprecated 'auto_pad' attribute
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if(paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if(paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(
                    fmt::format("Invalid auto_pad attribute for node {}. Only SAME_UPPER, SAME_LOWER or VALID "
                                "supported and found {} {}",
                                node.name(),
                                paddingString,
                                CHECK_LOCATION().AsString()));
            }
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth  = inputInfo.GetShape()[3];

            uint32_t weightHeight;
            uint32_t weightWidth;
            std::vector<uint32_t> kernel_shape = ReadOptionalNodeUint32ListAttribute(node, "kernel_shape");
            if (kernel_shape.empty())
            {
                const TensorInfo weightTensorInfo = *m_TensorsInfo[node.input(1)].m_info;
                weightHeight = weightTensorInfo.GetShape()[2];
                weightWidth  = weightTensorInfo.GetShape()[3];
            }
            else
            {
                weightHeight = kernel_shape[0];
                weightWidth  = kernel_shape[1];
            }
            CalcPadding(inputHeight,
                        weightHeight,
                        desc.m_StrideY,
                        desc.m_DilationY,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        weightWidth,
                        desc.m_StrideX,
                        desc.m_DilationX,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop    = pads[0];
        desc.m_PadLeft   = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight  = pads[3];
    }

    uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
    if(group > 1)
    {
        if (group > inputInfo.GetShape()[1])
        {
            throw ParseException(
                fmt::format("Error parsing Convolution node: {}. "
                            "The 'group'={} parameter cannot be larger than the "
                            "channel of the input shape={} (in NCHW format). {}",
                            node.name(),
                            group,
                            inputInfo.GetShape()[1],
                            CHECK_LOCATION().AsString()));
        }
        else if (group == inputInfo.GetShape()[1])
        {
            // we use a depthwise convolution here, because the number of groups equals the
            // input channels
            AddConvLayerWithDepthwiseConv(node, desc);
            return;
        }
        else
        {
            // TODO: split the input by channels into channels/groups separate convolutions
            // and concatenate the results afterwards
            throw ParseException(fmt::format("Error parsing Convolution node: {}. "
                                             "The 'group'={} parameter should be 1 or be equal to the "
                                             "channel of the input shape={} (in NCHW format). {}",
                                             node.name(),
                                             group,
                                             inputInfo.GetShape()[1],
                                             CHECK_LOCATION().AsString()));
        }
    }

    desc.m_BiasEnabled = (node.input_size() == 3);
    armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, node.name().c_str());
    std::vector<std::string> tensorIndexes = {node.input(0), node.input(1)};

    auto weightTensor = CreateConstTensor(node.input(1));

    IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(weightTensor.first);
    weightsLayer->GetOutputSlot(0).SetTensorInfo(weightTensor.first.GetInfo());
    weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        auto biasTensor = CreateConstTensor(node.input(2));

        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(biasTensor.first);
        biasLayer->GetOutputSlot(0).SetTensorInfo(biasTensor.first.GetInfo());
        biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));

        tensorIndexes.emplace_back(node.input(2));
    }

    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, tensorIndexes);

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

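// Flatten collapses the input into a 2D tensor around "axis":
// dimension1 = d_0 * ... * d_(axis-1) and dimension2 = d_axis * ... * d_n.
// For example, an input of shape {2,3,4,5} with axis=2 flattens to {6,20},
// and axis=0 yields {1,120}.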
void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT);

    int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

    /// Negative axis conversion
    if (axis < 0)
    {
        axis += inputShape.GetNumDimensions();
    }

    /// Check that the axis is within the input's dimensions
    if (axis < 0 || axis >= inputShape.GetNumDimensions())
    {
        throw ParseException(fmt::format("Axis '{}' invalid. Tensor has '{}' dimensions in FlattenLayer '{}'",
                                         axis, inputShape.GetNumDimensions(), node.name()));
    }

    /// If the chosen axis is 0, dimension1 is always 1 in the output; both default to 1 because 0 is an invalid size
    uint dimension1{1};
    uint dimension2{1};
    uint i{0};

    /// dimension1 = (d_0 * d_1 ... d_(axis-1))
    for (i = 0; i < axis; i++)
    {
        dimension1 *= inputShape[i];
    }

    /// dimension2 = (d_axis * d_(axis+1) ... d_n)
    for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++)
    {
        dimension2 *= inputShape[i];
    }

    TensorShape outputShape{dimension1, dimension2};

    auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
    CreateReshapeLayer(node.input(0), node.output(0), node.name());
}

void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    armnn::GatherDescriptor gatherDescriptor;
    gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));

    IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    const TensorShape& inputShape   = m_TensorsInfo[node.input(0)].m_info->GetShape();
    const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
                                        m_TensorsInfo[node.input(0)].m_dtype);
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    RegisterInputSlots(layer, { node.input(0), node.input(1) });

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, { node.output(0) });
}

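// Gemm computes Y = alpha * op(A) * op(B) + beta * C and is lowered onto a
// FullyConnected layer: transB maps to m_TransposeWeightMatrix, transA inserts an
// explicit Transpose on input A, and alpha/beta are folded in by scaling the
// weights/bias through Linear activation layers (y = a*x with a = alpha or beta).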
void OnnxParserImpl::ParseGemm(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    int transA = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transA", 0));
    int transB = static_cast<int>(ReadOptionalNodeUint32Attribute(node, "transB", 0));
    float alpha = ReadOptionalNodeFloatAttribute(node, "alpha", 1.0);
    float beta = ReadOptionalNodeFloatAttribute(node, "beta", 1.0);
    bool biasEnabled = node.input_size() == 3;

    TensorShape input0Shape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[node.input(1)].m_info->GetShape();

    // if transB != 0, transpose input1 (the weight matrix of the FullyConnected layer)
    armnn::FullyConnectedDescriptor fullyConnectedDescriptor;
    fullyConnectedDescriptor.m_BiasEnabled = biasEnabled;
    fullyConnectedDescriptor.m_TransposeWeightMatrix = transB;

    IConnectableLayer* layer = nullptr;

    // Just add a FullyConnected layer, weights and biases are handled as inputs now.
    layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    // if transA != 0, add a transpose to input0
    if (transA != 0)
    {
        std::string transAName = "transpose_" + node.input(0);
        armnn::TransposeDescriptor transposeADescriptor;
        transposeADescriptor.m_DimMappings = { 1, 0 };
        IConnectableLayer* transALayer = m_Network->AddTransposeLayer(transposeADescriptor, transAName.c_str());
        ARMNN_ASSERT(transALayer != nullptr);
        auto transAInfo = ComputeOutputInfo({ transAName }, transALayer, { input0Shape });
        transALayer->GetOutputSlot(0).SetTensorInfo(transAInfo[0]);
        transALayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0u));
        // register the input connection slots for the layer, connections are made after all layers have been created
        RegisterInputSlot(transALayer, node.input(0), 0);
        input0Shape = transAInfo[0].GetShape();
    }
    else
    {
        RegisterInputSlot(layer, node.input(0), 0);
    }

    // Add constant layer to store weights/biases and connect to FullyConnected layer.
    if(m_TensorsInfo[node.input(1)].isConstant())
    {
        IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(1)).first);
        TensorInfo weightInfo = *m_TensorsInfo[node.input(1)].m_info;
        weightInfo.SetConstant();
        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);

        // if alpha != 1, multiply the weights by alpha
        if (alpha != 1)
        {
            std::string activationName = "activation_" + node.input(1);
            armnn::ActivationDescriptor activationDescriptor;
            activationDescriptor.m_A = alpha;
            activationDescriptor.m_Function = ActivationFunction::Linear;
            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
            ARMNN_ASSERT(actLayer != nullptr);

            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { weightInfo.GetShape() });
            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
            weightsLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
            input1Shape = actInfo[0].GetShape();
        }
        else
        {
            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
            input1Shape = weightInfo.GetShape();
        }
    }
    else
    {
        // if alpha != 1, multiply the weights by alpha
        if (alpha != 1)
        {
            std::string activationName = "activation_" + node.input(1);
            armnn::ActivationDescriptor activationDescriptor;
            activationDescriptor.m_A = alpha;
            activationDescriptor.m_Function = ActivationFunction::Linear;
            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
            ARMNN_ASSERT(actLayer != nullptr);

            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { input1Shape });
            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
            RegisterInputSlot(actLayer, node.input(1), 0);
            input1Shape = actInfo[0].GetShape();
        }
        else
        {
            RegisterInputSlot(layer, node.input(1), 1);
        }
    }

    if(biasEnabled && m_TensorsInfo[node.input(2)].isConstant())
    {
        To1DTensor(node.input(2), CHECK_LOCATION());
        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(node.input(2)).first);
        TensorInfo biasInfo = *m_TensorsInfo[node.input(2)].m_info;
        biasInfo.SetConstant();
        biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);

        // if beta != 1, multiply the bias by beta
        if (beta != 1)
        {
            std::string activationName = "activation_" + node.input(2);
            armnn::ActivationDescriptor activationDescriptor;
            activationDescriptor.m_A = beta;
            activationDescriptor.m_Function = ActivationFunction::Linear;
            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
            ARMNN_ASSERT(actLayer != nullptr);

            auto actInfo = ComputeOutputInfo({ activationName }, actLayer, { biasInfo.GetShape() });
            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
            biasLayer->GetOutputSlot(0).Connect(actLayer->GetInputSlot(0u));
        }
        else
        {
            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
        }
    }
    else if (biasEnabled)
    {
        // A non-constant input C (bias) of Gemm is only supported when it has exactly 1 dimension
        if (m_TensorsInfo[node.input(2)].m_info->GetNumDimensions() != 1)
        {
            throw ParseException(fmt::format("Gemm input C (bias) must be constant, or non-constant with "
                                             "exactly 1 dimension. Input '{}' in node '{}' is not supported {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        // if beta != 1, multiply the bias by beta
        if (beta != 1)
        {
            std::string activationName = "activation_" + node.input(2);
            armnn::ActivationDescriptor activationDescriptor;
            activationDescriptor.m_A = beta;
            activationDescriptor.m_Function = ActivationFunction::Linear;
            IConnectableLayer* actLayer = m_Network->AddActivationLayer(activationDescriptor, activationName.c_str());
            ARMNN_ASSERT(actLayer != nullptr);

            auto actInfo = ComputeOutputInfo({ activationName },
                                             actLayer,
                                             { m_TensorsInfo[node.input(2)].m_info->GetShape() });
            actLayer->GetOutputSlot(0).SetTensorInfo(actInfo[0]);
            actLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
            RegisterInputSlot(actLayer, node.input(2), 0);
        }
        else
        {
            RegisterInputSlot(layer, node.input(2), 2);
        }
    }

    // Set the final output of the FullyConnected layer
    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { input0Shape, input1Shape });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc = Pooling2dDescriptor();
    desc.m_PoolType = PoolingAlgorithm::Average;

    // the kernel size equals the input's spatial dimensions
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    desc.m_PoolWidth  = inputShape[3];
    desc.m_PoolHeight = inputShape[2];

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc;
    desc.m_PoolType = PoolingAlgorithm::Max;
    desc.m_PaddingMethod = PaddingMethod::Exclude;
    AddPoolingLayer(node, desc);
}

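// Shape returns the dimensions of its input as a 1-D INT64 tensor, e.g. an input of
// shape {1,3,224,224} produces the tensor [1, 3, 224, 224].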
void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

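// Reshape with a constant shape tensor is applied directly. A non-constant shape input
// is only supported when it is 1-D with at most two elements: one element flattens to
// {numElements}, two elements produce {batch, numElements / batch}. For example, an
// input of shape {4,3,2} with a two-element shape tensor is reshaped to {4,6}.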
void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT); // input
    CHECK_VALID_DATATYPE(node.name(), node.input(1),
                         m_TensorsInfo[node.input(1)].m_dtype,
                         onnx::TensorProto::INT64); // shape

    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

    std::vector<unsigned int> targetShape;
    if(m_TensorsInfo[node.input(1)].isConstant())
    {
        unsigned int dims = static_cast<unsigned int>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
        targetShape.reserve(dims);

        for(uint i = 0; i < dims; i++)
        {
            int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
            targetShape.push_back(static_cast<unsigned int>(val));
        }
    }
    else
    {
        // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
        unsigned int dims = m_TensorsInfo[node.input(1)].m_info->GetNumDimensions();
        TensorShape shapes = m_TensorsInfo[node.input(1)].m_info->GetShape();
        if (dims != 1 || shapes[0] > 2)
        {
            throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}",
                                             node.input(1),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }

        unsigned int numInputElements = m_TensorsInfo[node.input(0)].m_info->GetNumElements();
        if (shapes[0] == 1)
        {
            targetShape = { numInputElements };
        }
        else if (shapes[0] == 2)
        {
            targetShape = { inputShape[0], numInputElements / inputShape[0] };
        }
    }

    if(m_TensorsInfo[node.input(0)].isConstant())
    {
        // make a new constant tensor -> move the data to the output tensor (the shape is already good in the output tensor)
        if(m_TensorsInfo.count(node.output(0)) == 0)
        {
            m_TensorsInfo[node.output(0)] = OnnxTensor();
        }
        m_TensorsInfo[node.output(0)].m_tensor =
            std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
    }
    else
    {
        if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
        {
            auto outInfo = ComputeReshapeInfo(
                TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
                inputShape, node.output(0));
            m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
        }

        CreateReshapeLayer(node.input(0), node.output(0), node.name());
    }
}

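// Unsqueeze inserts size-1 dimensions at the given (sorted) axes; the axes come from
// the "axes" attribute or, in newer opsets, from a constant second input. For example,
// an input of shape {3,4} with axes [0,3] becomes {1,3,4,1}.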
void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
    CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);

    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    std::vector<uint32_t> dims;
    if (node.input_size() == 1 && node.attribute_size() > 0)
    {
        dims = ReadMandatoryNodeUint32ListAttribute(node, "axes");
    }
    else
    {
        CHECK_VALID_DATATYPE(node.name(), node.input(1),
                             m_TensorsInfo[node.input(1)].m_dtype,
                             onnx::TensorProto::INT64); // axes

        auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data();
        uint numDim = armnn::numeric_cast<uint>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());

        for(uint i = 0; i < numDim; i++)
        {
            uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i]));
            dims.push_back(uint32Value);
        }
    }

    // Ensure that the axes are sorted
    std::sort(dims.begin(), dims.end());

    std::vector<unsigned int> targetShape;

    if (inputShape.GetDimensionality() != Dimensionality::Scalar)
    {
        for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
        {
            targetShape.push_back(inputShape[i]);
        }
    }

    for(uint i = 0; i < dims.size(); i++)
    {
        targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
    }

    auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
                                      inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
    m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;

    CreateReshapeLayer(node.input(0), node.output(0), node.name());
}

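// Prepends size-1 dimensions to input0 until it matches the rank of input1, e.g. {5}
// against {2,3,5} yields {1,1,5}. A Reshape layer is added for non-constant tensors;
// constant tensors just get updated metadata and are materialised later by the
// consuming layer.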
void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
                                         const std::string& input0,
                                         const std::string& input1)
{
    // input0 should be reshaped to have the same number of dimensions as input1
    TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    uint32_t diff = input1Shape.GetNumDimensions() - input0Shape.GetNumDimensions();
    std::vector<uint32_t> newShape;
    while(diff > 0)
    {
        newShape.push_back(1);
        diff--;
    }
    for (uint dim = 0; dim < input0Shape.GetNumDimensions(); ++dim)
    {
        newShape.push_back(input0Shape[dim]);
    }
    outputTensorInfo.SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));

    // add the new tensor to m_TensorsInfo
    m_TensorsInfo[outputName] = OnnxTensor();
    m_TensorsInfo[outputName].m_info = std::make_unique<TensorInfo>(outputTensorInfo);

    // add a reshape layer if the parent was not constant...
    if(!m_TensorsInfo[input0].isConstant())
    {
        CreateReshapeLayer(input0, outputName, fmt::format("Add:reshapeOf{}", input0));
    }
    else // ...otherwise make it constant and it will be created in Add
    {
        m_TensorsInfo[outputName].m_tensor = std::make_unique<onnx::TensorProto>(*m_TensorsInfo[input0].m_tensor);
    }
}

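// Creates an InputLayer for every non-constant graph input. Inputs with an unspecified
// (dynamic) shape must be given a concrete shape through the inputShapes map accepted
// by the CreateNetworkFrom* overloads; otherwise parsing fails.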
void OnnxParserImpl::SetupInputLayers()
{
    // Find the user inputs and add their layers
    for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
    {
        auto input = m_Graph->input(inputIndex);
        if (!m_TensorsInfo[input.name()].isConstant())
        {
            IConnectableLayer* layer =
                m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
            TensorInfo tensorInfo = *m_TensorsInfo[input.name()].m_info;
            if (tensorInfo.GetShape().GetDimensionality() == Dimensionality::NotSpecified)
            {
                if (m_InputShapes.find(input.name()) == m_InputShapes.end())
                {
                    throw ParseException(fmt::format("The parser does not support dynamic tensors; "
                                                     "please specify an input shape for {}. {}",
                                                     input.name(),
                                                     CHECK_LOCATION().AsString()));
                }
                else
                {
                    tensorInfo.SetShape(m_InputShapes[input.name()]);
                    m_TensorsInfo[input.name()].m_info = std::make_unique<TensorInfo>(tensorInfo);
                }
            }
            layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

            m_InputInfos[input.name()] = tensorInfo;

            RegisterOutputSlots(layer, { input.name() });
        }
    }
}

void OnnxParserImpl::SetupOutputLayers()
{
    if(m_Graph->output_size() == 0)
    {
        throw ParseException(fmt::format("The given model does not have any outputs {}",
                                         CHECK_LOCATION().AsString()));
    }

    for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
    {
        IConnectableLayer* layer =
            m_Network->AddOutputLayer(static_cast<armnn::LayerBindingId>(outputIndex),
                                      m_Graph->output(outputIndex).name().c_str());

        RegisterInputSlots(layer, { m_Graph->output(outputIndex).name() });
    }
}

void OnnxParserImpl::RegisterInputSlot(IConnectableLayer* layer,
                                       const std::string& tensorId,
                                       unsigned int slotIndex)
{
    armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));

    auto it = m_TensorConnections.find(tensorId);

    if (it == m_TensorConnections.end())
    {
        // First time seeing this tensor, we need to map it
        m_TensorConnections[tensorId] = TensorSlots();
    }
    m_TensorConnections[tensorId].inputSlots.push_back(slot);
}

void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumInputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor inputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumInputSlots(),
                        CHECK_LOCATION().AsString()));
    }

    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            // First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }
        m_TensorConnections[tensorId].inputSlots.push_back(slot);
    }
}

void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumOutputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor outputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumOutputSlots(),
                        CHECK_LOCATION().AsString()));
    }

    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            // First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }

        TensorSlots& tensorSlots = m_TensorConnections[tensorId];

        // assuming there is only one producer for that tensor
        if (tensorSlots.outputSlot != nullptr)
        {
            throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
                                             "tensor:{} {}",
                                             tensorId,
                                             CHECK_LOCATION().AsString()));
        }
        tensorSlots.outputSlot = slot;
    }
}

BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->input_size(); ++i)
    {
        auto input = m_Graph->input(i);
        if(input.name() == name)
        {
            auto it = m_InputInfos.find(name);

            if (it != m_InputInfos.end())
            {
                return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
            }
        }
    }
    throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->output_size(); ++i)
    {
        auto output = m_Graph->output(i);
        if(output.name() == name)
        {
            auto it = m_OutputInfos.find(name);

            if (it != m_OutputInfos.end())
            {
                return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
            }
        }
    }
    throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model)
{
    if(model == nullptr)
    {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> inputNames;
    std::map<std::string, bool> isConstant;
    for(const auto& tensor : model->graph().initializer())
    {
        isConstant[tensor.name()] = true;
    }
    for(const auto& input : model->graph().input())
    {
        auto it = isConstant.find(input.name());
        if(it == isConstant.end())
        {
            inputNames.push_back(input.name());
        }
    }
    return inputNames;
}

std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model)
{
    if(model == nullptr)
    {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> outputNames;
    for(const auto& output : model->graph().output())
    {
        outputNames.push_back(output.name());
    }
    return outputNames;
}

const std::string OnnxParserImpl::GetVersion()
{
    return ONNX_PARSER_VERSION;
}

} // namespace armnnOnnxParser