blob: 6caf690935a803dc62b6065698b8ab4247949363 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#include "OnnxParser.hpp"
6
Matthew Sloyanac001ee2021-02-03 10:43:04 +00007#include "armnnOnnxParser/Version.hpp"
8
Matthew Bentham39ef3e52020-01-20 10:09:09 +00009#include <armnn/Descriptors.hpp>
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +010010#include <armnn/utility/Assert.hpp>
Matthew Sloyan589e3e82020-09-11 16:17:48 +010011#include <armnn/utility/NumericCast.hpp>
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +010012#include <ParserHelper.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010013#include <VerificationHelpers.hpp>
14
James Ward58dec6b2020-09-11 17:32:44 +010015#include <fmt/format.h>
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +010016
telsoa01c577f2c2018-08-31 09:22:23 +010017#include <google/protobuf/text_format.h>
18#include <google/protobuf/io/zero_copy_stream_impl.h>
19
Matthew Sloyanac001ee2021-02-03 10:43:04 +000020#include <iostream>
telsoa01c577f2c2018-08-31 09:22:23 +010021#include <numeric>
Jan Eilers53ef7952021-06-02 12:01:25 +010022#include <armnnUtils/Permute.hpp>
telsoa01c577f2c2018-08-31 09:22:23 +010023
24using namespace armnn;
25
26namespace armnnOnnxParser
27{
Kevin Mayef33cb12021-01-29 14:24:57 +000028
29IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {}
30
31IOnnxParser::~IOnnxParser() = default;
32
33IOnnxParser* IOnnxParser::CreateRaw()
34{
35 return new IOnnxParser();
36}
37
38IOnnxParserPtr IOnnxParser::Create()
39{
40 return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy);
41}
42
43void IOnnxParser::Destroy(IOnnxParser* parser)
44{
45 delete parser;
46}
47
48armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile)
49{
50 return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
51}
52
53armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
54{
55 return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
56}
57
58armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText)
59{
60 return pOnnxParserImpl->CreateNetworkFromString(protoText);
61}
62
63BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
64{
65 return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
66}
67
68BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const
69{
70 return pOnnxParserImpl->GetNetworkOutputBindingInfo(name);
71}
72
telsoa01c577f2c2018-08-31 09:22:23 +010073namespace
74{
75void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes,
76 const onnx::TensorProto::DataType actualValue,
77 const char* validExpr,
78 std::string nodeName,
79 std::string tensorName,
80 const armnn::CheckLocation& location)
81{
82 bool isValid = std::any_of(validInputTypes.begin(),
83 validInputTypes.end(),
84 [&actualValue](onnx::TensorProto::DataType x) { return x == actualValue; } );
85 if (!isValid)
86 {
87 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +010088 fmt::format("Datatype {} is not valid for tensor '{}' of node '{}', not in {{{}}}. {}",
89 onnx::TensorProto::DataType_Name(actualValue),
90 tensorName,
91 nodeName,
92 validExpr,
93 location.AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +010094 }
95}
96
// Calls CheckValidDataType with the variadic arguments used twice: once as the list of
// accepted types and once, stringified, as the text shown in the error message.
// The call site's CHECK_LOCATION() is passed as the location.
#define CHECK_VALID_DATATYPE(NODE, TENSOR, ACTUAL, ...) \
    CheckValidDataType({__VA_ARGS__}, ACTUAL, #__VA_ARGS__, NODE, TENSOR, CHECK_LOCATION())
99
100using StrTypeListPair = std::pair<const char*, std::initializer_list<onnx::TensorProto::DataType>>;
101#define STR_LIST(...) StrTypeListPair(#__VA_ARGS__, {__VA_ARGS__})
102
103template <typename Callable>
104void ReadMandatoryNodeAttributeImpl(const onnx::NodeProto& node,
105 const std::string& attribName,
106 onnx::AttributeProto::AttributeType expectedType,
107 Callable callable)
108{
109 auto attribs = node.attribute();
110 int attriNum = 0;
111 while (attriNum < node.attribute_size())
112 {
113 if (attribs.Get(attriNum).name() == attribName)
114 {
115 if (attribs.Get(attriNum).type() == expectedType)
116 {
117 callable(attribs.Get(attriNum));
118 }
119 else
120 {
James Ward58dec6b2020-09-11 17:32:44 +0100121 throw ParseException(fmt::format("Attribute {} of node {} expected to have {} as "
122 "onnx::AttributeProto::AttributeType, but found {} instead {}",
123 attribName,
124 node.name(),
125 onnx::AttributeProto::AttributeType_Name(expectedType),
126 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
127 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100128 }
129 break;
130 }
131 ++attriNum;
132 }
133 if (attriNum == node.attribute_size())
134 {
James Ward58dec6b2020-09-11 17:32:44 +0100135 throw ParseException(fmt::format("Could not find required attribute {} in node {} {}",
136 attribName, node.name(), CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100137 }
138}
139
140template <typename Callable>
141void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
142 const std::string& attribName,
143 onnx::AttributeProto::AttributeType expectedType,
144 Callable callable)
145{
146 auto attribs = node.attribute();
147 for (int attriNum = 0; attriNum < node.attribute_size(); ++attriNum)
148 {
149 if (attribs.Get(attriNum).name() == attribName)
150 {
151 if (attribs.Get(attriNum).type() == expectedType)
152 {
153 callable(attribs.Get(attriNum));
154 }
155 else
156 {
James Ward58dec6b2020-09-11 17:32:44 +0100157 throw ParseException(
158 fmt::format("Attribute {} of node {} expected to have {} as onnx::AttributeProto::AttributeType, "
159 "but found {} instead {}",
160 attribName,
161 node.name(),
162 onnx::AttributeProto::AttributeType_Name(expectedType),
163 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
164 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100165 }
166 }
167 }
168}
169
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +0100170int ReadMandatoryNodeIntAttribute(const onnx::NodeProto& node,
171 const std::string& name)
172{
173 int attribValue = 0;
174 ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
175 [&attribValue](const onnx::AttributeProto& attrValue)
176 {
177 attribValue = CHECKED_INT32(attrValue.i());
178 });
179 return attribValue;
180}
181
Ryan OSheaed27ee72020-04-22 16:37:29 +0100182int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
183 const std::string& name,
184 const int64_t defaultValue = 0)
185{
186 int64_t attribValue = defaultValue;
187 ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
188 [&attribValue](const onnx::AttributeProto& attrValue)
189 {
190 attribValue = attrValue.i();
191 });
192 return attribValue;
193}
194
telsoa01c577f2c2018-08-31 09:22:23 +0100195std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
196 const std::string& name)
197{
198 std::vector<uint32_t> attriList;
199 ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
200 [&attriList](const onnx::AttributeProto& attrValue)
201 {
202 for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
203 {
204 attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
205 }
206 });
207 return attriList;
208}
209
210uint32_t ReadOptionalNodeUint32Attribute(const onnx::NodeProto& node,
211 const std::string& name,
212 const uint32_t defaultVal = 0u)
213{
214 uint32_t attribValue = defaultVal;
215 ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
216 [&attribValue](const onnx::AttributeProto& attrValue)
217 {
218 attribValue = CHECKED_NON_NEGATIVE(CHECKED_INT32((attrValue.i())));
219 });
220 return attribValue;
221}
222
223std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const onnx::NodeProto& node,
224 const std::string& name)
225{
226 std::vector<uint32_t> attriList;
227 ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
228 [&attriList](const onnx::AttributeProto& attrValue)
229 {
230 for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
231 {
232 attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
233 }
234 });
235
236 return attriList;
237}
238
239float ReadOptionalNodeFloatAttribute(const onnx::NodeProto& node,
240 const std::string& name,
241 const float defaultValue = 0.0f)
242{
243 float attribValue = defaultValue;
244 ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::FLOAT,
245 [&attribValue](const onnx::AttributeProto& attrValue)
246 {
247 attribValue = attrValue.f();
248 });
249 return attribValue;
250}
251
252std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const std::string& name)
253{
254 std::string attribValue = "";
255 ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::STRING,
256 [&attribValue](const onnx::AttributeProto& attrValue)
257 {
258 attribValue = attrValue.s();
259 });
260 return attribValue;
261}
262
Tee Jungfcf6fd52019-11-01 05:27:28 +0000263armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
telsoa01c577f2c2018-08-31 09:22:23 +0100264{
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100265 DataType type;
266 switch(data_type)
267 {
268 case onnx::TensorProto::FLOAT:
269 {
270 type = DataType::Float32;
telsoa01c577f2c2018-08-31 09:22:23 +0100271 break;
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100272 }
273 case onnx::TensorProto::INT32:
274 case onnx::TensorProto::INT64:
275 {
276 type = DataType::Signed32;
277 break;
278 }
279 default:
280 {
281 throw ParseException(
282 fmt::format("'{}' is not a currently supported datatype for tensor {}."
283 " Supported dataTypes are FLOAT, INT32 and INT64. {}",
284 onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
285 name,
286 CHECK_LOCATION().AsString() ));
287 }
288 }
Tee Jungcaf2bdd2019-11-13 07:23:14 +0000289
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100290 // To avoid crashes by trivial tensors
291 if (shape.empty())
292 {
293 return TensorInfo(TensorShape(Dimensionality::Scalar), type);
294 }
Tee Jungcaf2bdd2019-11-13 07:23:14 +0000295
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100296 return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
Tee Jungfcf6fd52019-11-01 05:27:28 +0000297}
298
299armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
300{
301 const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
302 std::vector<unsigned int> shapeDims;
303 for (int i = 0; i < onnxShape.dim_size(); ++i)
304 {
305 shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
306 }
307
308 return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
309}
310
311armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
312{
313 std::vector<unsigned int> shapeDims;
Ryan OShea337c17f2020-02-21 12:33:17 +0000314
Tee Jungfcf6fd52019-11-01 05:27:28 +0000315 for (auto dim: tensor.dims())
316 {
317 shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
318 }
319
320 return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
telsoa01c577f2c2018-08-31 09:22:23 +0100321}
322
323std::string TensorInfoAsString(const TensorInfo& info,
324 const std::string& name,
325 const onnx::TensorProto::DataType& type)
326{
327 const TensorShape shape = info.GetShape();
328 std::stringstream ss;
329 ss << "tensor '" << name << "' contains "
330 << onnx::TensorProto::DataType_Name(type)
331 << " and has shape [";
332
333 for (uint32_t i = 0; i < shape.GetNumDimensions() - 1; ++i)
334 {
335 ss << shape[i] << ", ";
336 }
337 ss << shape[shape.GetNumDimensions() - 1] << "]";
338 return ss.str();
339}
340
// Computes the front/back padding along one spatial dimension so that the output size is
// ceil(inputSize / stride).
//
// The total padding is the amount by which the span covered by the strided, dilated
// kernel exceeds the input. When the kernel span does not exceed the input (for example a
// 1x1 kernel with stride 2) no padding is needed; the previous unsigned subtraction wrapped
// around in that case and produced padding values close to 2^31.
//
// @param inputSize    Input extent along the dimension.
// @param filterSize   Kernel extent along the dimension.
// @param stride       Stride along the dimension (must be non-zero).
// @param dilation     Dilation along the dimension.
// @param paddingFront Receives the padding before the data.
// @param paddingBack  Receives the padding after the data.
// @param isUpper      When the total padding is odd, true puts the extra element at the
//                     back, false puts it at the front.
void CalcPadding(uint32_t inputSize,
                 uint32_t filterSize,
                 uint32_t stride,
                 uint32_t dilation,
                 uint32_t* paddingFront,
                 uint32_t* paddingBack,
                 bool isUpper)
{
    const uint32_t outputSize  = (inputSize + stride - 1) / stride;
    const uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
    const uint32_t coveredSize = (outputSize - 1) * stride + dilatedSize;

    *paddingFront = 0;
    *paddingBack  = 0;

    if (coveredSize <= inputSize)
    {
        // The kernel never reaches past the input, so no padding is required.
        return;
    }

    const uint32_t totalPadding = coveredSize - inputSize;
    *paddingFront = totalPadding / 2;
    *paddingBack  = totalPadding / 2;

    if (totalPadding % 2 == 1)
    {
        if (isUpper)
        {
            *paddingBack += 1;
        }
        else
        {
            *paddingFront += 1;
        }
    }
}
366
Ryan OSheaed27ee72020-04-22 16:37:29 +0100367TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
telsoa01c577f2c2018-08-31 09:22:23 +0100368 const TensorShape& inShape,
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100369 const std::string& outName,
370 DataType dataType = DataType::Float32)
telsoa01c577f2c2018-08-31 09:22:23 +0100371{
372 std::vector<int> targetDims;
Ryan OSheaed27ee72020-04-22 16:37:29 +0100373 for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
telsoa01c577f2c2018-08-31 09:22:23 +0100374 {
Ryan OSheaed27ee72020-04-22 16:37:29 +0100375 int val = CHECKED_INT32(targetShapeTensor[i]);
telsoa01c577f2c2018-08-31 09:22:23 +0100376 if(val == 0)
377 {
378 targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
379 }
380 else
381 {
382 targetDims.push_back(val);
383 }
384 }
385
386 std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
387 const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
388 if (stretchDim != targetDims.end())
389 {
390 if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
391 {
392 std::stringstream ss;
393 ss << "[ ";
394 for(uint i = 0; i < targetDims.size() - 1; ++i)
395 {
396 ss << targetDims[i] << ", ";
397 }
398 ss << targetDims[targetDims.size() - 1] << " ]";
399
James Ward58dec6b2020-09-11 17:32:44 +0100400 throw ParseException(
401 fmt::format("Error during creation of reshaped tensor '{}'. At most one component of shape can be "
402 " -1 and here, shape is {} {}",
403 outName,
404 ss.str(),
405 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100406 }
407
Matthew Sloyan589e3e82020-09-11 16:17:48 +0100408 auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
telsoa01c577f2c2018-08-31 09:22:23 +0100409 -1, std::multiplies<int32_t>()));
410 auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
411 outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
412 }
413 TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100414 return TensorInfo(outShape, dataType);
telsoa01c577f2c2018-08-31 09:22:23 +0100415}
416
417} //namespace
418
Kevin Mayef33cb12021-01-29 14:24:57 +0000419const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = {
420 { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization},
421 { "GlobalAveragePool", &OnnxParserImpl::ParseGlobalAveragePool},
422 { "AveragePool", &OnnxParserImpl::ParseAveragePool },
423 { "Clip", &OnnxParserImpl::ParseClip },
424 { "Constant", &OnnxParserImpl::ParseConstant },
425 { "MaxPool", &OnnxParserImpl::ParseMaxPool },
426 { "Reshape", &OnnxParserImpl::ParseReshape },
427 { "Sigmoid", &OnnxParserImpl::ParseSigmoid },
428 { "Tanh", &OnnxParserImpl::ParseTanh },
429 { "Relu", &OnnxParserImpl::ParseRelu },
430 { "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu },
431 { "Conv", &OnnxParserImpl::ParseConv },
432 { "Add", &OnnxParserImpl::ParseAdd },
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +0100433 { "Flatten", &OnnxParserImpl::ParseFlatten },
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100434 { "Shape", &OnnxParserImpl::ParseShape },
435 { "Gather", &OnnxParserImpl::ParseGather },
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +0100436 { "Unsqueeze", &OnnxParserImpl::ParseUnsqueeze },
437 { "Concat", &OnnxParserImpl::ParseConcat }
telsoa01c577f2c2018-08-31 09:22:23 +0100438};
439
440template<typename TypePair, typename Location>
Kevin Mayef33cb12021-01-29 14:24:57 +0000441void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
telsoa01c577f2c2018-08-31 09:22:23 +0100442 TypePair validInputs,
443 const Location& location)
444{
445 for(auto input : node.input())
446 {
447 CheckValidDataType(validInputs.second,
448 m_TensorsInfo[input].m_dtype,
449 validInputs.first,
450 node.name(),
451 input,
452 location);
453 }
454}
455
// Calls OnnxParserImpl::ValidateInputs for NODE, passing the call site's CHECK_LOCATION().
#define VALID_INPUTS(NODE, VALID_INPUTS) \
    OnnxParserImpl::ValidateInputs(NODE, VALID_INPUTS, CHECK_LOCATION())
460
Kevin Mayef33cb12021-01-29 14:24:57 +0000461std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
462 const IConnectableLayer* layer,
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100463 std::vector<TensorShape> inputShapes,
464 const onnx::TensorProto::DataType& dataType)
telsoa01c577f2c2018-08-31 09:22:23 +0100465{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100466 ARMNN_ASSERT(! outNames.empty());
telsoa01c577f2c2018-08-31 09:22:23 +0100467 bool needCompute = std::any_of(outNames.begin(),
468 outNames.end(),
469 [this](std::string name)
470 {
471 return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
472 });
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100473 std::vector<TensorInfo> outInfo;
474 //if the output info(s) are not here, we need to compute them
475 std::vector<TensorShape> inferredShapes;
476 DataType armnnType = DataType::Float32;
477 if(needCompute) {
478 inferredShapes = layer->InferOutputShapes(inputShapes);
479 ARMNN_ASSERT(inferredShapes.size() == outNames.size());
480 switch (dataType) {
481 case onnx::TensorProto::FLOAT: {
482 armnnType = DataType::Float32;
483 break;
484 }
485 case onnx::TensorProto::INT32:
486 case onnx::TensorProto::INT64: {
487 armnnType = DataType::Signed32;
488 break;
489 }
490 default: {
491 throw ParseException(
492 fmt::format("'{}' is not a currently supported datatype for {}."
493 " Supported dataTypes are FLOAT, INT32 and INT64. {}",
494 onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(dataType)),
495 layer->GetName(),
496 CHECK_LOCATION().AsString()));
497 }
498 }
499 }
500 for (uint i = 0; i < outNames.size(); ++i)
501 {
502 if(needCompute)
503 {
504 m_TensorsInfo[outNames[i]] = OnnxTensor();
505 m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
506 TensorInfo(inferredShapes[i], armnnType));
507 m_TensorsInfo[outNames[i]].m_dtype = dataType;
508 }
telsoa01c577f2c2018-08-31 09:22:23 +0100509 outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
Narumol Prangnawarat452274c2021-09-23 16:12:19 +0100510 }
511 return outInfo;
telsoa01c577f2c2018-08-31 09:22:23 +0100512}
513
Kevin Mayef33cb12021-01-29 14:24:57 +0000514OnnxParserImpl::OnnxParserImpl()
telsoa01c577f2c2018-08-31 09:22:23 +0100515 : m_Network(nullptr, nullptr)
516{
517}
518
Kevin Mayef33cb12021-01-29 14:24:57 +0000519void OnnxParserImpl::ResetParser()
telsoa01c577f2c2018-08-31 09:22:23 +0100520{
521 m_Network = armnn::INetworkPtr(nullptr, nullptr);
522 m_Graph = nullptr;
523}
524
Kevin Mayef33cb12021-01-29 14:24:57 +0000525void OnnxParserImpl::Cleanup()
telsoa01c577f2c2018-08-31 09:22:23 +0100526{
527 m_TensorConnections.clear();
528 m_TensorsInfo.clear();
529 m_OutputsMap.clear();
530 m_OutputsFusedAndUsed.clear();
531}
532
Jan Eilers53ef7952021-06-02 12:01:25 +0100533template<typename T>
534std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
535CreateConstTensorImpl(const T* bufferPtr,
536 armnn::TensorInfo& tensorInfo,
537 const armnn::Optional<armnn::PermutationVector&> permutationVector)
telsoa01c577f2c2018-08-31 09:22:23 +0100538{
Jan Eilers53ef7952021-06-02 12:01:25 +0100539 ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());
540
541 std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
542
543 if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
544 {
545 tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
546 armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
547 reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
548 }
549 else
550 {
551 ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
552 }
553
554 return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
555}
556
557std::pair<ConstTensor, std::unique_ptr<float[]>>
558OnnxParserImpl::CreateConstTensor(const std::string name,
559 armnn::Optional<armnn::PermutationVector&> permutationVector)
560{
561 TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
telsoa01c577f2c2018-08-31 09:22:23 +0100562 onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
563
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100564 //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
565 CHECK_VALID_DATATYPE(name, onnxTensor.name(),
566 static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
567
Matthew Sloyan81beae32021-07-13 19:46:11 +0100568 // Makes sure IsConstant flag is set.
569 tensorInfo.SetConstant();
570
Jan Eilers53ef7952021-06-02 12:01:25 +0100571 // Const tensors requires at least a list of values
572 if (tensorInfo.GetNumElements() == 0)
573 {
574 throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
575 name,
576 CHECK_LOCATION().AsString()));
577 }
578
telsoa01c577f2c2018-08-31 09:22:23 +0100579 auto srcData = onnxTensor.float_data().data();
Pablo Tello3dcc1c62019-04-24 14:20:21 +0100580 // Copy the value list entries into the destination
581 if (!onnxTensor.has_raw_data())
telsoa01c577f2c2018-08-31 09:22:23 +0100582 {
Pablo Tello3dcc1c62019-04-24 14:20:21 +0100583 if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
584 {
James Ward58dec6b2020-09-11 17:32:44 +0100585 throw ParseException(
586 fmt::format("The number of data provided ({}) does not match the tensor '{}' number of "
587 "elements ({}) {}",
588 onnxTensor.float_data_size(),
589 name,
590 tensorInfo.GetNumElements(),
591 CHECK_LOCATION().AsString()));
Pablo Tello3dcc1c62019-04-24 14:20:21 +0100592 }
Jan Eilers53ef7952021-06-02 12:01:25 +0100593 return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
telsoa01c577f2c2018-08-31 09:22:23 +0100594 }
Pablo Tello3dcc1c62019-04-24 14:20:21 +0100595 else
596 {
Jan Eilers53ef7952021-06-02 12:01:25 +0100597 return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
598 tensorInfo,
599 permutationVector);
Pablo Tello3dcc1c62019-04-24 14:20:21 +0100600 }
telsoa01c577f2c2018-08-31 09:22:23 +0100601}
602
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +0100603std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
604OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
605 armnn::Optional<armnn::PermutationVector&> permutationVector)
606{
607 TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
608 onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
609
610 CHECK_VALID_DATATYPE(name, onnxTensor.name(),
611 static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);
612
613 // Makes sure IsConstant flag is set.
614 tensorInfo.SetConstant();
615 uint numElements = tensorInfo.GetNumElements();
616
617 // Const tensors requires at least a list of values
618 if (numElements == 0)
619 {
620 throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
621 name,
622 CHECK_LOCATION().AsString()));
623 }
624
625 // Copy the value list entries into the destination
626 if (!onnxTensor.has_raw_data())
627 {
628 auto srcData = onnxTensor.int64_data().data();
629 if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
630 {
631 throw ParseException(
632 fmt::format("The number of data provided ({}) does not match the tensor '{}' number of "
633 "elements ({}) {}",
634 onnxTensor.int64_data_size(),
635 name,
636 tensorInfo.GetNumElements(),
637 CHECK_LOCATION().AsString()));
638 }
639
640 std::vector<int32_t> int32Data;
641 for(uint i = 0; i < numElements; i++)
642 {
643 int32_t int32Value = CHECKED_INT32(srcData[i]);
644 int32Data.push_back(int32Value);
645 }
646
647 return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
648 }
649 else
650 {
651 auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
652 std::vector<int32_t> int32Data;
653 for(uint i = 0; i < numElements; i++)
654 {
655 int32_t int32Value = CHECKED_INT32(srcData[i]);
656 int32Data.push_back(int32Value);
657 }
658 return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
659 }
660}
661
Kevin Mayef33cb12021-01-29 14:24:57 +0000662ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
telsoa01c577f2c2018-08-31 09:22:23 +0100663{
664 FILE* fd = fopen(graphFile, "r");
665
666 if (fd == nullptr)
667 {
James Ward58dec6b2020-09-11 17:32:44 +0100668 throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100669 }
670
671 // Parse the file into a message
672 ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
673 using google::protobuf::io::FileInputStream;
674 std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
675 bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
676 fclose(fd);
677
678 if (!success)
679 {
680 std::stringstream error;
681 error << "Failed to parse graph file";
James Ward58dec6b2020-09-11 17:32:44 +0100682 throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100683 }
684 return modelProto;
685}
686
Kevin Mayef33cb12021-01-29 14:24:57 +0000687INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
telsoa01c577f2c2018-08-31 09:22:23 +0100688{
689 ResetParser();
690 ModelPtr modelProto = LoadModelFromTextFile(graphFile);
691 return CreateNetworkFromModel(*modelProto);
692}
693
694
Kevin Mayef33cb12021-01-29 14:24:57 +0000695ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
telsoa01c577f2c2018-08-31 09:22:23 +0100696{
697 FILE* fd = fopen(graphFile, "rb");
698
699 if (fd == nullptr)
700 {
James Ward58dec6b2020-09-11 17:32:44 +0100701 throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100702 }
703
704 // Parse the file into a message
705 ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
706
707 google::protobuf::io::FileInputStream inStream(fileno(fd));
708 google::protobuf::io::CodedInputStream codedStream(&inStream);
Nikhil Raje5181532020-10-09 14:52:25 +0100709 codedStream.SetTotalBytesLimit(INT_MAX);
telsoa01c577f2c2018-08-31 09:22:23 +0100710 bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
711 fclose(fd);
712
713 if (!success)
714 {
715 std::stringstream error;
716 error << "Failed to parse graph file";
James Ward58dec6b2020-09-11 17:32:44 +0100717 throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100718 }
719 return modelProto;
720
721}
722
Kevin Mayef33cb12021-01-29 14:24:57 +0000723INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
telsoa01c577f2c2018-08-31 09:22:23 +0100724{
725 ResetParser();
726 ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
727 return CreateNetworkFromModel(*modelProto);
728}
729
Kevin Mayef33cb12021-01-29 14:24:57 +0000730ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
telsoa01c577f2c2018-08-31 09:22:23 +0100731{
732 if (protoText == "")
733 {
James Ward58dec6b2020-09-11 17:32:44 +0100734 throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
735 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100736 }
737 // Parse the string into a message
738 ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
739 bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
740 if (!success)
741 {
742 std::stringstream error;
743 error << "Failed to parse graph file";
James Ward58dec6b2020-09-11 17:32:44 +0100744 throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100745 }
746 return modelProto;
747}
748
Kevin Mayef33cb12021-01-29 14:24:57 +0000749INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText)
telsoa01c577f2c2018-08-31 09:22:23 +0100750{
751 ResetParser();
752 ModelPtr modelProto = LoadModelFromString(protoText);
753 return CreateNetworkFromModel(*modelProto);
754}
755
Kevin Mayef33cb12021-01-29 14:24:57 +0000756INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
telsoa01c577f2c2018-08-31 09:22:23 +0100757{
758 m_Network = INetwork::Create();
759 try
760 {
761 m_Graph = std::make_unique<onnx::GraphProto>(*model.mutable_graph());
762 LoadGraph();
763 }
764 catch (const ParseException& e)
765 {
766 Cleanup();
767 throw e;
768 }
769 Cleanup();
770 return std::move(m_Network);
771}
772
Kevin Mayef33cb12021-01-29 14:24:57 +0000773void OnnxParserImpl::LoadGraph()
telsoa01c577f2c2018-08-31 09:22:23 +0100774{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100775 ARMNN_ASSERT(m_Graph.get() != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +0100776
777 //Fill m_TensorsInfo with the shapes and value of every tensor
778 SetupInfo(m_Graph->mutable_output());
779 SetupInfo(m_Graph->mutable_input());
780 SetupInfo(m_Graph->mutable_value_info());
781
782 for (auto tensor : m_Graph->initializer())
783 {
784 m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
Tee Jungfcf6fd52019-11-01 05:27:28 +0000785 m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
786 m_TensorsInfo[tensor.name()].m_dtype =
787 static_cast<onnx::TensorProto::DataType>(tensor.data_type());
telsoa01c577f2c2018-08-31 09:22:23 +0100788 }
789
790 SetupInputLayers();
791 SetupOutputLayers();
792
793 //Detect FullyConnected layers with bias and update the FusedAndUsed map acccordingly
794 DetectFullyConnected();
795
796 //Parsing the graph
797 for(size_t nodeIndex = 0; nodeIndex < static_cast<size_t>(m_Graph->node_size()); nodeIndex++)
798 {
799 auto node = m_Graph->node(static_cast<int>(nodeIndex));
800 const std::string& operation = node.op_type();
801
802 // check which layers we handled already (add and matmul fused as FC)
Ryan OShea337c17f2020-02-21 12:33:17 +0000803 if (operation == "MatMul" )
telsoa01c577f2c2018-08-31 09:22:23 +0100804 {
805 if(m_OutputsFusedAndUsed[nodeIndex].inputForNodes != m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.size())
806 {
807 //Node which can not be fused as a FullyConnected layer (used in layers as a simple matmul output)
808 AddFullyConnected(node);
809 }
810 }
811 else if (!(m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) && operation == "Add")
812 {
813 int matmulIndex = static_cast<int> (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes[0]);
814 AddFullyConnected(m_Graph->node(matmulIndex), &node);
815 }
816 else if (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) //node is not part of a fused layer
817 {
818 auto it = m_ParserFunctions.find(operation);
819 if (it != m_ParserFunctions.end())
820 {
821 auto func = it->second;
822 (this->*func)(node);
823 }
824 else
825 {
James Ward58dec6b2020-09-11 17:32:44 +0100826 throw ParseException(fmt::format("Unsupported operation {} for node '{}' {}",
827 operation,
828 node.name(),
829 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100830 }
831 }
832 }
833
834 //Making the connections between outputs and inputs of each layers
835 for (const auto& tensorCon : m_TensorConnections)
836 {
837 if (tensorCon.second.outputSlot != nullptr)
838 {
839 for (size_t inputSlotIdx = 0; inputSlotIdx < tensorCon.second.inputSlots.size(); ++inputSlotIdx)
840 {
841 tensorCon.second.outputSlot->Connect(*(tensorCon.second.inputSlots[inputSlotIdx]));
842 }
843 }
844 }
845}
846
Kevin Mayef33cb12021-01-29 14:24:57 +0000847void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
telsoa01c577f2c2018-08-31 09:22:23 +0100848{
849 for (auto tensor : *list)
850 {
851 m_TensorsInfo[tensor.name()] = OnnxTensor();
852 m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
Matteo Martincighe355dc22018-12-10 13:45:27 +0000853 m_TensorsInfo[tensor.name()].m_dtype =
854 static_cast<onnx::TensorProto::DataType>(tensor.type().tensor_type().elem_type());
telsoa01c577f2c2018-08-31 09:22:23 +0100855 }
856}
857
Kevin Mayef33cb12021-01-29 14:24:57 +0000858void OnnxParserImpl::DetectFullyConnected()
telsoa01c577f2c2018-08-31 09:22:23 +0100859{
860 m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary());
861 auto matmulAndConstant = [&](const std::string& constInput,
862 const std::string& matmulInput,
863 int& nodeIndex)
864 {
865 auto matmulIt = m_OutputsMap.find(matmulInput);
866 if(matmulIt != m_OutputsMap.end() && matmulIt->second.first->op_type() == "MatMul"
867 && m_TensorsInfo[constInput].isConstant())
868 {
869 nodeIndex = matmulIt->second.second;
870 return true;
871 }
872 return false;
873 };
874
875 for(int nodeIndex = 0; nodeIndex < m_Graph->node_size(); nodeIndex++)
876 {
877 const onnx::NodeProto* node = &m_Graph->node(nodeIndex);
878 for (const std::string& output : node->output())
879 {
880 m_OutputsMap[output] = std::make_pair(node, nodeIndex);
881 }
882
883 for (const std::string& input : node->input()) //count how many time a node is used as input
884 {
885 auto matmulIt = m_OutputsMap.find(input);
886 if(matmulIt != m_OutputsMap.end()){
887 ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes; //node used
888 }
889 }
890
891 if (node->op_type() == "Add")
892 {
893 int matmulIndex = 0;
894 if (matmulAndConstant(node->input(0), node->input(1), matmulIndex) ||
895 matmulAndConstant(node->input(1), node->input(0), matmulIndex))
896 {
897 //matmul and add were fused
898 m_OutputsFusedAndUsed[static_cast<size_t>(matmulIndex)].fusedWithNodes
899 .push_back(static_cast<size_t>(nodeIndex));
900
901 m_OutputsFusedAndUsed[static_cast<size_t>(nodeIndex)].fusedWithNodes
902 .push_back(static_cast<size_t>(matmulIndex));
903 }
904 }
905 }
906
907 for (auto output: m_Graph->output()) { //Add usages as output of the graph in count of usages
908 auto matmulIt = m_OutputsMap.find(output.name());
909 if(matmulIt != m_OutputsMap.end()){
910 ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes;
911 }
912 }
913}
914
915template<typename Location>
Kevin Mayef33cb12021-01-29 14:24:57 +0000916void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node,
917 std::string* inputName,
918 std::string* constName,
919 const Location& location)
telsoa01c577f2c2018-08-31 09:22:23 +0100920{
921 int cstIndex;
922 if (m_TensorsInfo[node.input(0)].isConstant())
923 {
924 cstIndex = 0;
925 }
926 else if (m_TensorsInfo[node.input(1)].isConstant())
927 {
928 cstIndex = 1;
929 }
930 else
931 {
James Ward58dec6b2020-09-11 17:32:44 +0100932 throw ParseException(fmt::format("One of the input tensors ('{}' or '{}') should be constant in node '{}' {}",
933 node.input(0),
934 node.input(1),
935 node.name(),
936 location.AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100937 }
938 if(constName)
939 {
940 *constName = node.input(cstIndex);
941 }
942 if(inputName)
943 {
944 *inputName = node.input(!cstIndex);
945 }
946}
947
948template<typename Location>
Kevin Mayef33cb12021-01-29 14:24:57 +0000949void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location)
telsoa01c577f2c2018-08-31 09:22:23 +0100950{
951 TensorShape shape = m_TensorsInfo[name].m_info->GetShape();
952 std::vector<uint32_t> newShape;
953 for(uint i = 0; i < shape.GetNumDimensions() - 1; ++i)
954 {
955 if(shape[i] != 1)
956 {
James Ward58dec6b2020-09-11 17:32:44 +0100957 throw ParseException(
958 fmt::format("Only tensors with shape [1, ..., 1, X] can be converted to 1D and {} {}",
959 TensorInfoAsString(*m_TensorsInfo[name].m_info, name, m_TensorsInfo[name].m_dtype),
960 location.AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +0100961 }
962 }
963 newShape.push_back(shape[shape.GetNumDimensions() - 1]);
964
965 m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
966}
967
Kevin Mayef33cb12021-01-29 14:24:57 +0000968void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
Ryan OSheaed27ee72020-04-22 16:37:29 +0100969{
970 ARMNN_ASSERT(node.op_type() == "Conv");
971
972 DepthwiseConvolution2dDescriptor desc;
973 desc.m_PadLeft = convDesc.m_PadLeft;
974 desc.m_PadRight = convDesc.m_PadRight;
975 desc.m_PadTop = convDesc.m_PadTop;
976 desc.m_PadBottom = convDesc.m_PadBottom;
977 desc.m_StrideX = convDesc.m_StrideX;
978 desc.m_StrideY = convDesc.m_StrideY;
979 desc.m_BiasEnabled = convDesc.m_BiasEnabled;
980
981 armnn::IConnectableLayer* layer;
Jan Eilers53ef7952021-06-02 12:01:25 +0100982
983 // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs dephtwise weights layout [1,H,W,O]
984 armnn::PermutationVector perVec {3,0,1,2};
985 auto weightTensor = CreateConstTensor(node.input(1), perVec);
Ryan OSheaed27ee72020-04-22 16:37:29 +0100986
987 if (node.input_size() == 3)
988 {
989 if(!m_TensorsInfo[node.input(2)].isConstant())
990 {
James Ward58dec6b2020-09-11 17:32:44 +0100991 throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
992 node.input(2),
993 node.name(),
994 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +0100995 }
996 desc.m_BiasEnabled = true;
997 auto biasTensor = CreateConstTensor(node.input(2));
998 layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
999 weightTensor.first,
1000 Optional<ConstTensor>(biasTensor.first),
1001 node.name().c_str());
1002 }
1003 else
1004 {
1005 layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
1006 weightTensor.first,
1007 EmptyOptional(),
1008 node.name().c_str());
1009 }
1010 ARMNN_ASSERT(layer != nullptr);
1011
1012 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
1013 { m_TensorsInfo[node.input(0)].m_info->GetShape(),
Jan Eilers53ef7952021-06-02 12:01:25 +01001014 weightTensor.first.GetInfo().GetShape() });
Ryan OSheaed27ee72020-04-22 16:37:29 +01001015
1016 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1017
1018 // register the input connection slots for the layer, connections are made after all layers have been created
1019 // only the tensors for the inputs are relevant, exclude the const tensors
1020 RegisterInputSlots(layer, {node.input(0)});
1021
1022 // register the output connection slots for the layer, connections are made after all layers have been created
1023 RegisterOutputSlots(layer, {node.output(0)});
1024}
1025
Kevin Mayef33cb12021-01-29 14:24:57 +00001026void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
telsoa01c577f2c2018-08-31 09:22:23 +01001027{
1028
1029 // find matmul inputs
1030 std::string weightName;
1031 std::string inputName;
1032 CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
1033 CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
1034 VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));
1035
1036 GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());
1037
1038 FullyConnectedDescriptor desc;
1039 desc.m_BiasEnabled = addNode != nullptr;
1040
1041 IConnectableLayer* layer = nullptr;
1042 if(desc.m_BiasEnabled)
1043 {
1044 // find bias const
1045 std::string biasName;
1046 CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
1047 CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
1048 VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));
1049
1050 GetInputAndParam(*addNode, nullptr, &biasName, CHECK_LOCATION());
1051
1052 //Output shape is [1, weights[1]] and 1d vec in ONNX can be [1,X] so we convert biases to "armnn" 1D
1053 To1DTensor(biasName, CHECK_LOCATION());
1054 TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
1055 TensorInfo biasInfo = *m_TensorsInfo[biasName].m_info;
1056
1057 if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
1058 {
James Ward58dec6b2020-09-11 17:32:44 +01001059 throw ParseException(
1060 fmt::format("Shape of weights '{}' and bias of following Add node '{}' do not match : {}"
1061 " and {} ( /!\\ bias should be a 1D tensor) {}",
1062 weightName,
1063 addNode->name(),
1064 TensorInfoAsString(*m_TensorsInfo[weightName].m_info, weightName,
1065 m_TensorsInfo[weightName].m_dtype),
1066 TensorInfoAsString(*m_TensorsInfo[biasName].m_info, biasName,
1067 m_TensorsInfo[biasName].m_dtype ),
1068 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001069 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01001070
1071 // Just add a FullyConnected layer, weights and biases are handled as inputs now.
1072 layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001073 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001074
1075 auto outputInfo = ComputeOutputInfo({addNode->output(0)}, layer,
1076 {m_TensorsInfo[inputName].m_info->GetShape(),
1077 m_TensorsInfo[weightName].m_info->GetShape()});
telsoa01c577f2c2018-08-31 09:22:23 +01001078 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1079
Matthew Sloyan81beae32021-07-13 19:46:11 +01001080 // Add constant layer to store weights/biases and connect to FullyConnected layer..
1081 if(m_TensorsInfo[weightName].isConstant())
1082 {
1083 IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);
1084
1085 weightInfo.SetConstant();
1086 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
1087 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1088 }
1089
1090 if(m_TensorsInfo[biasName].isConstant())
1091 {
1092 IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);
1093
1094 biasInfo.SetConstant();
1095 biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
1096 biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
1097 }
1098
1099 RegisterInputSlots(layer, {inputName, weightName, biasName});
telsoa01c577f2c2018-08-31 09:22:23 +01001100 RegisterOutputSlots(layer, {addNode->output(0)});
1101 }
1102 else
1103 {
Matthew Sloyan81beae32021-07-13 19:46:11 +01001104 layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001105 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001106
1107 auto outputInfo = ComputeOutputInfo({matmulNode.output(0)}, layer,
1108 {m_TensorsInfo[inputName].m_info->GetShape(),
1109 m_TensorsInfo[weightName].m_info->GetShape()});
1110 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1111
Matthew Sloyan81beae32021-07-13 19:46:11 +01001112 // Add constant layer to store weights and connect to FullyConnected layer.
1113 if(m_TensorsInfo[weightName].isConstant())
1114 {
1115 TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
1116 IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);
1117
1118 weightInfo.SetConstant();
1119 weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
1120 weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
1121 }
1122
1123 RegisterInputSlots(layer, {inputName, weightName});
telsoa01c577f2c2018-08-31 09:22:23 +01001124 RegisterOutputSlots(layer, {matmulNode.output(0)});
1125 }
1126}
1127
Kevin Mayef33cb12021-01-29 14:24:57 +00001128void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
telsoa01c577f2c2018-08-31 09:22:23 +01001129{
1130
1131 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
1132 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1133
1134 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1135
1136 std::vector<uint32_t> kernel_shape = ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape"); //size of pool win
1137 std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
1138 std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
1139
1140 desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
1141 desc.m_PoolWidth = kernel_shape[1];
1142 desc.m_PoolHeight = kernel_shape[0];
1143
1144 if(strides.empty())
1145 {
1146 desc.m_StrideX = 1;
1147 desc.m_StrideY = 1;
1148 }
1149 else
1150 {
1151 desc.m_StrideX = strides[1];
1152 desc.m_StrideY = strides[0];
1153 }
1154
1155 //Check new padding version first
1156 if(pads.empty())
1157 {
1158 //Check deprecated version
1159 std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
1160 if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
1161 {
1162 bool isUpper;
1163 if( paddingString == "SAME_LOWER")
1164 {
1165 isUpper = false;
1166 }
1167 else if (paddingString == "SAME_UPPER")
1168 {
1169 isUpper = true;
1170 }
1171 else
1172 {
James Ward58dec6b2020-09-11 17:32:44 +01001173 throw ParseException(fmt::format("Invalid auto_pad attribute for node {}. "
1174 "Only SAME_UPPER, SAME_LOWER or VALID supported and found {} {}",
1175 node.name(),
1176 paddingString,
1177 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001178 }
1179 auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
1180 uint32_t inputHeight = inputInfo.GetShape()[2];
1181 uint32_t inputWidth = inputInfo.GetShape()[3];
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001182 CalcPadding(inputHeight,
1183 desc.m_PoolHeight,
1184 desc.m_StrideY,
1185 1u,
1186 &desc.m_PadTop,
1187 &desc.m_PadBottom,
1188 isUpper);
1189 CalcPadding(inputWidth,
1190 desc.m_PoolWidth,
1191 desc.m_StrideX,
1192 1u,
1193 &desc.m_PadLeft,
1194 &desc.m_PadRight,
1195 isUpper);
telsoa01c577f2c2018-08-31 09:22:23 +01001196 }
1197 }
1198 else
1199 {
1200 desc.m_PadTop = pads[0];
1201 desc.m_PadLeft = pads[1];
1202 desc.m_PadBottom = pads[2];
1203 desc.m_PadRight = pads[3];
1204 }
1205
1206 IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001207 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001208
1209 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1210 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1211
1212 // register the input connection slots for the layer, connections are made after all layers have been created
1213 // only the tensors for the inputs are relevant, exclude the const tensors
1214 RegisterInputSlots(layer, {node.input(0)});
1215
1216 // register the output connection slots for the layer, connections are made after all layers have been created
1217 RegisterOutputSlots(layer, {node.output(0)});
1218}
1219
Kevin Mayef33cb12021-01-29 14:24:57 +00001220std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0,
1221 const std::string& input1)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001222{
1223 std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);
1224
1225 TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
1226 TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();
1227
1228 if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
1229 {
James Ward58dec6b2020-09-11 17:32:44 +01001230 auto outputName = fmt::format("reshape_output_{}", input1);
Ryan OSheaed27ee72020-04-22 16:37:29 +01001231 PrependForBroadcast(outputName, input1, input0);
1232 inputs.second = outputName;
1233 }
1234 else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
1235 {
James Ward58dec6b2020-09-11 17:32:44 +01001236 auto outputName = fmt::format("reshape_output_{}", input0);
Ryan OSheaed27ee72020-04-22 16:37:29 +01001237 PrependForBroadcast(outputName, input0, input1);
1238 inputs.first = outputName;
1239 }
1240 return inputs;
1241}
1242
Kevin Mayef33cb12021-01-29 14:24:57 +00001243void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001244{
1245 auto armnnTensor = CreateConstTensor(tensorName);
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001246 IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
1247 layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
1248 RegisterOutputSlots(layer, {tensorName});
1249}
Ryan OSheaed27ee72020-04-22 16:37:29 +01001250
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001251void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
1252{
1253 auto armnnTensor = CreateInt64ConstTensor(tensorName);
Ryan OSheaed27ee72020-04-22 16:37:29 +01001254 IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
1255 layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
1256 RegisterOutputSlots(layer, {tensorName});
1257}
1258
Kevin Mayef33cb12021-01-29 14:24:57 +00001259void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName,
1260 const std::string& outputName,
1261 const std::string& layerName)
telsoa01c577f2c2018-08-31 09:22:23 +01001262{
1263 const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info;
1264 ReshapeDescriptor reshapeDesc;
1265 reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
1266
1267 IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001268 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001269 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1270
1271 // register the input connection slots for the layer, connections are made after all layers have been created
1272 // only the tensors for the inputs are relevant, exclude the const tensors
1273 RegisterInputSlots(layer, {inputName});
1274
1275 // register the output connection slots for the layer, connections are made after all layers have been created
1276 RegisterOutputSlots(layer, {outputName});
1277}
1278
Kevin Mayef33cb12021-01-29 14:24:57 +00001279void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
telsoa01c577f2c2018-08-31 09:22:23 +01001280{
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001281 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
telsoa01c577f2c2018-08-31 09:22:23 +01001282 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1283
1284 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1285
1286 ActivationDescriptor desc;
Tee Jung7ff9a602019-11-01 07:04:42 +00001287 desc.m_Function = func;
telsoa01c577f2c2018-08-31 09:22:23 +01001288
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001289 if (func == ActivationFunction::BoundedReLu)
1290 {
Narumol Prangnawaratf106ab72021-09-15 17:30:37 +01001291 if (node.input_size() == 1 && node.attribute_size() > 0)
1292 {
1293 desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
1294 desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
1295 }
1296 else
1297 {
1298 desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
1299 desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
1300 }
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001301 }
1302
telsoa01c577f2c2018-08-31 09:22:23 +01001303 IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001304 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001305
1306 auto outputInfo = ComputeOutputInfo({ node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1307 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1308
1309 // register the input connection slots for the layer, connections are made after all layers have been created
1310 // only the tensors for the inputs are relevant, exclude the const tensors
1311 RegisterInputSlots(layer, {node.input(0)});
1312
1313 // register the output connection slots for the layer, connections are made after all layers have been created
1314 RegisterOutputSlots(layer, {node.output(0)});
1315}
1316
Kevin Mayef33cb12021-01-29 14:24:57 +00001317void OnnxParserImpl::ParseClip(const onnx::NodeProto& node)
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001318{
1319 ParseActivation(node, ActivationFunction::BoundedReLu);
1320}
1321
Kevin Mayef33cb12021-01-29 14:24:57 +00001322void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001323{
1324 ParseActivation(node, ActivationFunction::Sigmoid);
1325}
1326
Kevin Mayef33cb12021-01-29 14:24:57 +00001327void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001328{
1329 ParseActivation(node, ActivationFunction::TanH);
1330}
1331
Kevin Mayef33cb12021-01-29 14:24:57 +00001332void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001333{
1334 ParseActivation(node, ActivationFunction::ReLu);
1335}
1336
Kevin Mayef33cb12021-01-29 14:24:57 +00001337void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001338{
1339 ParseActivation(node, ActivationFunction::LeakyReLu);
1340}
telsoa01c577f2c2018-08-31 09:22:23 +01001341
Kevin Mayef33cb12021-01-29 14:24:57 +00001342void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001343{
Ryan OSheaed27ee72020-04-22 16:37:29 +01001344 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1345 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
telsoa01c577f2c2018-08-31 09:22:23 +01001346
Ryan OSheaed27ee72020-04-22 16:37:29 +01001347 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
telsoa01c577f2c2018-08-31 09:22:23 +01001348
Ryan OSheaed27ee72020-04-22 16:37:29 +01001349 // TODO: unify broadcast validation code across layers
1350 // tracked by: IVGCVSW-1576
telsoa01c577f2c2018-08-31 09:22:23 +01001351
Ryan OSheaed27ee72020-04-22 16:37:29 +01001352 // Checking broadcast compatibility : only scalar or 1D tensors
1353 auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
1354 auto input0 = *m_TensorsInfo[inputs.first].m_info;
1355 auto input1 = *m_TensorsInfo[inputs.second].m_info;
1356 ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
1357
1358 unsigned int numDims = input0.GetNumDimensions();
1359 for (unsigned int i = 0; i < numDims; i++)
telsoa01c577f2c2018-08-31 09:22:23 +01001360 {
Ryan OSheaed27ee72020-04-22 16:37:29 +01001361 unsigned int dim0 = input0.GetShape()[i];
1362 unsigned int dim1 = input1.GetShape()[i];
1363 if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
telsoa01c577f2c2018-08-31 09:22:23 +01001364 {
James Ward58dec6b2020-09-11 17:32:44 +01001365 throw ParseException(
1366 fmt::format("Broadcast is only supported for scalar or 1D tensors in Add node '{}'. "
1367 "Input dimensions should either match or one should be of size 1 and here, "
1368 "{} and {} {}",
1369 node.name(),
1370 TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
1371 m_TensorsInfo[inputs.first].m_dtype),
1372 TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
1373 m_TensorsInfo[inputs.second].m_dtype),
1374 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001375 }
telsoa01c577f2c2018-08-31 09:22:23 +01001376 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001377
1378
1379 IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001380 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001381
1382 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
Ryan OSheaed27ee72020-04-22 16:37:29 +01001383 { m_TensorsInfo[inputs.first].m_info->GetShape(),
1384 m_TensorsInfo[inputs.second].m_info->GetShape() });
telsoa01c577f2c2018-08-31 09:22:23 +01001385 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1386
Ryan OSheaed27ee72020-04-22 16:37:29 +01001387 // register the input connection -> for constant inputs, we need to make a newDim constant layer
1388 if(m_TensorsInfo[inputs.first].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001389 CreateConstantLayer(inputs.first, fmt::format("Add:constant_of_{}", node.input(0)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001390 }
1391 if(m_TensorsInfo[inputs.second].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001392 CreateConstantLayer(inputs.second, fmt::format("Add:constant_of_{}", node.input(1)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001393 }
1394 RegisterInputSlots(layer, {inputs.first, inputs.second});
telsoa01c577f2c2018-08-31 09:22:23 +01001395
Ryan OSheaed27ee72020-04-22 16:37:29 +01001396 // register the output connection
telsoa01c577f2c2018-08-31 09:22:23 +01001397 RegisterOutputSlots(layer, {node.output(0)});
1398}
1399
Kevin Mayef33cb12021-01-29 14:24:57 +00001400void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001401{
1402 Pooling2dDescriptor desc;
1403 desc.m_PoolType = PoolingAlgorithm::Average;
1404
1405 uint32_t count_include_pad = 0;
1406 count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
1407 if(count_include_pad) {
1408 desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
1409 }
1410 AddPoolingLayer(node, desc);
1411}
1412
Kevin Mayef33cb12021-01-29 14:24:57 +00001413void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001414{
1415 //IGNORE momentum parameter and spatial parameters
1416
1417 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
1418 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1419
1420 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1421 for(int ind = 1; ind < node.input_size(); ++ind)
1422 {
1423 auto tensor = node.input(ind);
1424 if(! m_TensorsInfo[tensor].isConstant())
1425 {
James Ward58dec6b2020-09-11 17:32:44 +01001426 throw ParseException(
1427 fmt::format("Input tensor '{}' should be constant in BatchNormalization node '{}' {}",
1428 tensor,
1429 node.name(),
1430 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001431 }
1432 }
1433
1434 float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
1435 BatchNormalizationDescriptor desc;
1436 desc.m_Eps = epsilon;
1437
1438 auto scaleTensor = CreateConstTensor(node.input(1));
1439 auto biasTensor = CreateConstTensor(node.input(2));
1440 auto meanTensor = CreateConstTensor(node.input(3));
1441 auto varTensor = CreateConstTensor(node.input(4));
1442
1443 IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
1444 meanTensor.first,
1445 varTensor.first,
1446 biasTensor.first,
1447 scaleTensor.first,
1448 node.name().c_str());
1449 ARMNN_ASSERT(layer != nullptr);
1450
1451 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1452 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1453
1454 RegisterInputSlots(layer, {node.input(0)}); //don't register constant inputs
1455
1456 // register the output connection
1457 RegisterOutputSlots(layer, {node.output(0)});
1458}
1459
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +01001460void OnnxParserImpl::ParseConcat(const onnx::NodeProto& node)
1461{
1462 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1463
1464 uint32_t numConcatView = static_cast<uint32_t>(node.input_size());
1465 uint32_t inputRank = m_TensorsInfo[node.input(0)].m_info->GetNumDimensions();
1466
1467 int axisInt = ReadMandatoryNodeIntAttribute(node, "axis");
1468
1469 unsigned int concatDimInput = static_cast<unsigned int>(
1470 (static_cast<int>(inputRank) + axisInt) % static_cast<int>(inputRank));
1471
1472 OriginsDescriptor concatDescriptor(numConcatView, inputRank);
1473 concatDescriptor.SetConcatAxis(concatDimInput);
1474
1475 unsigned int mergeDimOrigin = 0;
1476
1477 std::vector<TensorShape> inputShapes;
1478 std::vector<std::string> tensorIds;
1479
1480 for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
1481 {
1482 std::string nodeName = node.input(static_cast<int>(viewIndex));
1483 auto inputTensorInfo = *m_TensorsInfo[nodeName].m_info;
1484 inputShapes.push_back(inputTensorInfo.GetShape());
1485 tensorIds.push_back(nodeName);
1486
1487 // Set up concatDescriptor view origin
1488 armnnUtils::ProcessConcatInputTensorInfo(
1489 inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
1490 }
1491
1492 IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, node.name().c_str());
1493 ARMNN_ASSERT(layer != nullptr);
1494
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001495 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, inputShapes,
1496 m_TensorsInfo[node.input(0)].m_dtype);
Narumol Prangnawaratbc3bb622021-09-24 16:08:34 +01001497
1498 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1499
1500 // register the input connection slots for the layer, connections are made after all layers have been created
1501 RegisterInputSlots(layer, tensorIds);
1502
1503 // register the output connection slots for the layer, connections are made after all layers have been created
1504 RegisterOutputSlots(layer, { node.output(0) });
1505}
1506
Kevin Mayef33cb12021-01-29 14:24:57 +00001507void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001508{
1509 CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
1510 if (!node.attribute(0).has_t())
1511 {
James Ward58dec6b2020-09-11 17:32:44 +01001512 throw ParseException(fmt::format("Value not found for Constant node '{}' {}",
1513 node.name(),
1514 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001515 }
1516 const onnx::TensorProto& onnxTensor = node.attribute(0).t();
1517
Ryan OSheaed27ee72020-04-22 16:37:29 +01001518 //Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
1519 m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
1520 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
1521 m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
1522
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001523 if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
1524 {
1525 CreateConstantLayer(node.output(0), node.name());
1526 }
1527 else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
1528 {
1529 CreateInt64ConstantLayer(node.output(0), node.name());
1530 }
1531 else
1532 {
1533 throw ParseException(fmt::format("Data type not support for Constant node '{}' {}",
1534 node.name(),
1535 CHECK_LOCATION().AsString()));
1536 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001537}
1538
Kevin Mayef33cb12021-01-29 14:24:57 +00001539void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001540{
1541 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); //input, weight, (bias)
1542 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1543
1544 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1545
1546 if(m_TensorsInfo[node.input(0)].m_info->GetNumDimensions() != 4)
1547 {
James Ward58dec6b2020-09-11 17:32:44 +01001548 throw ParseException(
1549 fmt::format("ArmNN only supports 2D convolution and Conv layer '{}' input {} {}",
1550 node.name(),
1551 TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
1552 m_TensorsInfo[node.input(0)].m_dtype),
1553 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001554 }
1555
1556 if(!m_TensorsInfo[node.input(1)].isConstant())
1557 {
James Ward58dec6b2020-09-11 17:32:44 +01001558 throw ParseException(
1559 fmt::format("Weights '{}' should be constant in Conv layer '{}' {}",
1560 node.input(1),
1561 node.name(),
1562 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001563 }
1564
1565 auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
1566
telsoa01c577f2c2018-08-31 09:22:23 +01001567 Convolution2dDescriptor desc;
1568 desc.m_BiasEnabled = false;
1569
1570 std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
1571 if(strides.empty())
1572 {
1573 desc.m_StrideX = 1;
1574 desc.m_StrideY = 1;
1575 }
1576 else
1577 {
1578 desc.m_StrideX = strides[1];
1579 desc.m_StrideY = strides[0];
1580 }
1581
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001582 std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
1583 if(!dilations.empty())
1584 {
1585 desc.m_DilationX = dilations[1];
1586 desc.m_DilationY = dilations[0];
1587 }
1588
telsoa01c577f2c2018-08-31 09:22:23 +01001589 std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
1590 //Check new padding version first
1591 if(pads.empty())
1592 {
1593 //Check deprecated version
1594 std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
1595 if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
1596 {
1597 bool isUpper;
1598 if( paddingString == "SAME_LOWER")
1599 {
1600 isUpper = false;
1601 }
1602 else if (paddingString == "SAME_UPPER")
1603 {
1604 isUpper = true;
1605 }
1606 else
1607 {
James Ward58dec6b2020-09-11 17:32:44 +01001608 throw ParseException(
1609 fmt::format("Invalid auto_pad attribute for node {}. Only SAME_UPPER, SAME_LOWER or VALID "
1610 "supported and found {} {}",
1611 node.name(),
1612 paddingString,
1613 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001614 }
1615 uint32_t inputHeight = inputInfo.GetShape()[2];
1616 uint32_t inputWidth = inputInfo.GetShape()[3];
1617
1618 uint32_t weightHeight;
1619 uint32_t weightWidth;
1620 std::vector<uint32_t> kernel_shape = ReadOptionalNodeUint32ListAttribute(node, "kernel_shape");
1621 if (kernel_shape.empty())
1622 {
1623 const TensorInfo weightTensorInfo = *m_TensorsInfo[node.input(1)].m_info;
1624 weightHeight = weightTensorInfo.GetShape()[2];
1625 weightWidth = weightTensorInfo.GetShape()[3];
1626 }
1627 else
1628 {
1629 weightHeight = kernel_shape[0];
1630 weightWidth = kernel_shape[1];
1631 }
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001632 CalcPadding(inputHeight,
1633 weightHeight,
1634 desc.m_StrideY,
1635 desc.m_DilationY,
1636 &desc.m_PadTop,
1637 &desc.m_PadBottom,
1638 isUpper);
1639 CalcPadding(inputWidth,
1640 weightWidth,
1641 desc.m_StrideX,
1642 desc.m_DilationX,
1643 &desc.m_PadLeft,
1644 &desc.m_PadRight,
1645 isUpper);
telsoa01c577f2c2018-08-31 09:22:23 +01001646 }
1647 }
1648 else
1649 {
1650 desc.m_PadTop = pads[0];
1651 desc.m_PadLeft = pads[1];
1652 desc.m_PadBottom = pads[2];
1653 desc.m_PadRight = pads[3];
1654 }
1655
1656 uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
1657 if(group > 1)
1658 {
1659 if (group > inputInfo.GetShape()[1])
1660 {
1661 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01001662 fmt::format("Error parsing Convolution node: {}. "
1663 "The 'group'={} parameter cannot be larger than the "
1664 "channel of the input shape={} (in NCHW format). {}",
1665 node.name(),
1666 group,
1667 inputInfo.GetShape()[1],
1668 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001669 }
1670 else if (group == inputInfo.GetShape()[1])
1671 {
1672 // we use a depthwise convolution here, because the number of groups equals to the
1673 // input channels
1674 AddConvLayerWithDepthwiseConv(node, desc);
1675 return;
1676 }
1677 else
1678 {
1679 // TODO: split the input by channels into channels/groups separate convolutions
Jim Flynne242f2d2019-05-22 14:24:13 +01001680 // and concatenate the results afterwards
James Ward58dec6b2020-09-11 17:32:44 +01001681 throw ParseException(fmt::format("Error parsing Convolution node: {}. "
1682 "The 'group'={} parameter should be 1 or be equal to the "
1683 "channel of the input shape={} (in NCHW format). {}",
1684 node.name(),
1685 group,
1686 inputInfo.GetShape()[1],
1687 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001688 }
1689 }
1690
1691 armnn::IConnectableLayer* layer;
1692 auto weightTensor = CreateConstTensor(node.input(1));
1693
1694 if (node.input_size() == 3)
1695 {
1696 if(!m_TensorsInfo[node.input(2)].isConstant())
1697 {
James Ward58dec6b2020-09-11 17:32:44 +01001698 throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
1699 node.input(2),
1700 node.name(),
1701 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001702 }
1703 desc.m_BiasEnabled = true;
1704 auto biasTensor = CreateConstTensor(node.input(2));
1705 layer = m_Network->AddConvolution2dLayer(desc,
1706 weightTensor.first,
Matteo Martincighfc598e12019-05-14 10:36:13 +01001707 Optional<ConstTensor>(biasTensor.first),
telsoa01c577f2c2018-08-31 09:22:23 +01001708 node.name().c_str());
1709 }
1710 else
1711 {
1712 layer = m_Network->AddConvolution2dLayer(desc,
1713 weightTensor.first,
Matteo Martincighfc598e12019-05-14 10:36:13 +01001714 EmptyOptional(),
telsoa01c577f2c2018-08-31 09:22:23 +01001715 node.name().c_str());
1716 }
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001717 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001718
1719 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
1720 { m_TensorsInfo[node.input(0)].m_info->GetShape(),
1721 m_TensorsInfo[node.input(1)].m_info->GetShape() });
1722 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1723
1724 // register the input connection slots for the layer, connections are made after all layers have been created
1725 // only the tensors for the inputs are relevant, exclude the const tensors
1726 RegisterInputSlots(layer, {node.input(0)});
1727
1728 // register the output connection slots for the layer, connections are made after all layers have been created
1729 RegisterOutputSlots(layer, {node.output(0)});
1730}
1731
Kevin Mayef33cb12021-01-29 14:24:57 +00001732void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001733{
1734 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
1735 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1736
1737 CHECK_VALID_DATATYPE(node.name(), node.input(0),
1738 m_TensorsInfo[node.input(0)].m_dtype,
1739 onnx::TensorProto::FLOAT);
1740
1741 int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
1742 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1743
1744 /// Negative axis conversion
1745 if (axis < 0)
1746 {
1747 axis += inputShape.GetNumDimensions();
1748 }
1749
1750 /// Check Axis is within dimensions
1751 if (axis < 0 || axis >= inputShape.GetNumDimensions())
1752 {
James Ward58dec6b2020-09-11 17:32:44 +01001753 throw ParseException(fmt::format("Axis '{}' invalid. Tensor has '{}' dimensions in FlattenLayer '{}'",
1754 axis, inputShape.GetNumDimensions(), node.name()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001755 }
1756
1757 /// If axis chosen is 0 dimension1 will always be 1 in output , default dimension2 to 1 because 0 is invalid
1758 uint dimension1{1};
1759 uint dimension2{1};
1760 uint i{0};
1761
1762 /// dimension1 = (d_0 * d_1 ... d_(axis-1))
1763 for (i = 0; i < axis; i++){
1764 dimension1 *= inputShape[i];
1765 }
1766
1767 /// dimension2 = (d_axis * d_(axis+1) ... d_n)
1768 for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++){
1769 dimension2 *= inputShape[i];
1770 }
1771
1772 TensorShape outputShape{dimension1, dimension2};
1773
1774 auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
1775 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
1776 CreateReshapeLayer(node.input(0), node.output(0), node.name());
1777}
1778
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001779void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
1780{
1781 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1782 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1783
1784 armnn::GatherDescriptor gatherDescriptor;
1785 gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));
1786
1787 IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
1788 ARMNN_ASSERT(layer != nullptr);
1789
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001790 const TensorShape& inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1791 const TensorShape& indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
1792 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape },
1793 m_TensorsInfo[node.input(0)].m_dtype);
Narumol Prangnawaratf10b15a2021-09-17 21:08:57 +01001794 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1795
1796 // register the input connection slots for the layer, connections are made after all layers have been created
1797 RegisterInputSlots(layer, { node.input(0), node.input(1) });
1798
1799 // register the output connection slots for the layer, connections are made after all layers have been created
1800 RegisterOutputSlots(layer, { node.output(0) });
1801}
1802
Kevin Mayef33cb12021-01-29 14:24:57 +00001803void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001804{
1805 Pooling2dDescriptor desc = Pooling2dDescriptor();
1806 desc.m_PoolType = PoolingAlgorithm::Average;
1807
1808 //kernel size is the same as input
1809 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1810 desc.m_PoolWidth = inputShape[3];
1811 desc.m_PoolHeight = inputShape[2];
1812
1813 IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
1814 ARMNN_ASSERT(layer != nullptr);
1815
1816 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
1817 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1818
1819 // register the input connection slots for the layer, connections are made after all layers have been created
1820 // only the tensors for the inputs are relevant, exclude the const tensors
1821 RegisterInputSlots(layer, {node.input(0)});
1822
1823 // register the output connection slots for the layer, connections are made after all layers have been created
1824 RegisterOutputSlots(layer, {node.output(0)});
1825}
1826
Kevin Mayef33cb12021-01-29 14:24:57 +00001827void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001828{
1829 Pooling2dDescriptor desc;
1830 desc.m_PoolType = PoolingAlgorithm::Max;
1831 desc.m_PaddingMethod = PaddingMethod::Exclude;
1832 AddPoolingLayer(node, desc);
1833}
1834
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01001835void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
1836{
1837 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
1838 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1839
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01001840 IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
1841 ARMNN_ASSERT(layer != nullptr);
1842
1843 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001844 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape}, onnx::TensorProto::INT64);
Narumol Prangnawaratcdc495e2021-09-16 18:13:39 +01001845 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1846
1847 // register the input connection slots for the layer, connections are made after all layers have been created
1848 RegisterInputSlots(layer, {node.input(0)});
1849
1850 // register the output connection slots for the layer, connections are made after all layers have been created
1851 RegisterOutputSlots(layer, {node.output(0)});
1852}
1853
Kevin Mayef33cb12021-01-29 14:24:57 +00001854void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001855{
1856 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1857 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1858
1859 CHECK_VALID_DATATYPE(node.name(), node.input(0),
1860 m_TensorsInfo[node.input(0)].m_dtype,
1861 onnx::TensorProto::FLOAT); //input
1862 CHECK_VALID_DATATYPE(node.name(), node.input(1),
1863 m_TensorsInfo[node.input(1)].m_dtype,
1864 onnx::TensorProto::INT64); //shape
1865
1866 if(!m_TensorsInfo[node.input(1)].isConstant())
1867 {
James Ward58dec6b2020-09-11 17:32:44 +01001868 throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
1869 node.input(1),
1870 node.name(),
1871 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001872 }
1873
1874 if(m_TensorsInfo[node.input(0)].isConstant())
1875 {
1876 //make a new cst tensor -> move the data to the output tensor (the shape is already good in the output tensor)
1877 if(m_TensorsInfo.count(node.output(0)) == 0)
1878 {
1879 m_TensorsInfo[node.output(0)] = OnnxTensor();
1880 }
1881 m_TensorsInfo[node.output(0)].m_tensor =
1882 std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
1883 }
1884 else
1885 {
1886 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1887
1888 if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
1889 {
1890 uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
1891 TensorShape targetShape{static_cast<unsigned int>(dims), 1};
1892
1893 for(uint i = 0; i < dims; i++)
1894 {
1895 int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
1896 targetShape[i]= static_cast<unsigned int>(val);
1897 }
1898
1899 auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
1900 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
1901 }
1902
1903 CreateReshapeLayer(node.input(0), node.output(0), node.name());
1904 }
1905}
1906
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001907void OnnxParserImpl::ParseUnsqueeze(const onnx::NodeProto& node)
1908{
1909 CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.input_size()), 1, 2);
1910 CHECK_VALID_SIZE(armnn::numeric_cast<size_t>(node.output_size()), 1);
1911
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001912 TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
1913 std::vector<uint32_t> dims;
1914 if (node.input_size() == 1 && node.attribute_size() > 0)
1915 {
1916 dims = ReadMandatoryNodeUint32ListAttribute(node, "axes");
1917 }
1918 else
1919 {
1920 CHECK_VALID_DATATYPE(node.name(), node.input(1),
1921 m_TensorsInfo[node.input(1)].m_dtype,
1922 onnx::TensorProto::INT64); //axes
1923
1924 auto int64Axes = m_TensorsInfo[node.input(1)].m_tensor->int64_data().data();
1925 uint numDim = armnn::numeric_cast<uint>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
1926
1927 for(uint i = 0; i < numDim; i++)
1928 {
1929 uint32_t uint32Value = CHECKED_NON_NEGATIVE(CHECKED_INT32(int64Axes[i]));
1930 dims.push_back(uint32Value);
1931 }
1932 }
1933
1934 // Ensure that the axes are sorted
1935 std::sort(dims.begin(), dims.end());
1936
1937 std::vector<unsigned int> targetShape;
1938
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001939 if (inputShape.GetDimensionality() != Dimensionality::Scalar)
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001940 {
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001941 for(uint i = 0; i < inputShape.GetNumDimensions(); i++)
1942 {
1943 targetShape.push_back(inputShape[i]);
1944 }
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001945 }
1946
1947 for(uint i = 0; i < dims.size(); i++)
1948 {
1949 targetShape.insert(targetShape.begin() + armnn::numeric_cast<int>(dims[i]), 1);
1950 }
1951
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001952 auto outInfo = ComputeReshapeInfo(TensorShape(static_cast<unsigned int>(targetShape.size()), targetShape.data()),
1953 inputShape, node.output(0), m_TensorsInfo[node.input(0)].m_info->GetDataType());
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001954 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
Narumol Prangnawarat452274c2021-09-23 16:12:19 +01001955 m_TensorsInfo[node.output(0)].m_dtype = m_TensorsInfo[node.input(0)].m_dtype;
Narumol Prangnawaratfe6aa2f2021-09-23 16:11:17 +01001956
1957 CreateReshapeLayer(node.input(0), node.output(0), node.name());
1958}
1959
Kevin Mayef33cb12021-01-29 14:24:57 +00001960void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
1961 const std::string& input0,
1962 const std::string& input1)
telsoa01c577f2c2018-08-31 09:22:23 +01001963{
1964 //input0 should be reshaped to have same number of dim as input1
1965 TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info);
1966
1967 TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
1968 TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();
1969
1970 uint32_t diff = input1Shape.GetNumDimensions() - input0Shape.GetNumDimensions();
1971 std::vector<uint32_t> newShape;
1972 while(diff > 0)
1973 {
1974 newShape.push_back(1);
1975 diff--;
1976 }
1977 for (uint dim = 0; dim < input0Shape.GetNumDimensions(); ++dim)
1978 {
1979 newShape.push_back(input0Shape[dim]);
1980 }
1981 outputTensorInfo.SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
1982
1983 //add the new tensor to m_TensorsInfo
1984 m_TensorsInfo[outputName] = OnnxTensor();
1985 m_TensorsInfo[outputName].m_info = std::make_unique<TensorInfo>(outputTensorInfo);
1986
1987 //add reshape layer if the parent was not constant...
1988 if( ! m_TensorsInfo[input0].isConstant())
1989 {
James Ward58dec6b2020-09-11 17:32:44 +01001990 CreateReshapeLayer(input0, outputName, fmt::format("Add:reshapeOf{}", input0));
telsoa01c577f2c2018-08-31 09:22:23 +01001991 }
1992 else //make it constant and it will be create in Add
1993 {
1994 m_TensorsInfo[outputName].m_tensor = std::make_unique<onnx::TensorProto>(*m_TensorsInfo[input0].m_tensor);
1995
1996 }
1997}
1998
Kevin Mayef33cb12021-01-29 14:24:57 +00001999void OnnxParserImpl::SetupInputLayers()
telsoa01c577f2c2018-08-31 09:22:23 +01002000{
2001 //Find user input and add their layers
2002 for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
2003 {
2004 auto input = m_Graph->input(inputIndex);
2005 if (! m_TensorsInfo[input.name()].isConstant())
2006 {
2007 IConnectableLayer* layer =
2008 m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
2009 auto tensorInfo = ToTensorInfo(input);
2010 layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
2011
2012 RegisterOutputSlots(layer,{ input.name() });
2013 }
2014 }
2015}
2016
Kevin Mayef33cb12021-01-29 14:24:57 +00002017void OnnxParserImpl::SetupOutputLayers()
telsoa01c577f2c2018-08-31 09:22:23 +01002018{
2019 if(m_Graph->output_size() == 0)
2020 {
James Ward58dec6b2020-09-11 17:32:44 +01002021 throw ParseException(fmt::format("The given model does not have any outputs {}", CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002022 }
2023
2024 for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
2025 {
2026 IConnectableLayer* layer =
2027 m_Network->AddOutputLayer(static_cast<armnn::LayerBindingId>(outputIndex),
2028 m_Graph->output(outputIndex).name().c_str());
2029
2030 RegisterInputSlots(layer, { m_Graph->output(outputIndex).name() });
2031 }
2032}
2033
Kevin Mayef33cb12021-01-29 14:24:57 +00002034void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
telsoa01c577f2c2018-08-31 09:22:23 +01002035{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002036 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01002037 if (tensorIds.size() != layer->GetNumInputSlots())
2038 {
2039 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01002040 fmt::format("The number of tensor inputs ({}) does not match the number expected ({}) {}",
2041 tensorIds.size(),
2042 layer->GetNumInputSlots(),
2043 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002044 }
Matthew Sloyan81beae32021-07-13 19:46:11 +01002045
telsoa01c577f2c2018-08-31 09:22:23 +01002046 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
2047 {
2048 std::string tensorId = tensorIds[slotIndex];
2049 armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));
2050
2051 auto it = m_TensorConnections.find(tensorId);
2052
2053 if (it == m_TensorConnections.end())
2054 {
2055 //First time seing this tensor, we need to map it
2056 m_TensorConnections[tensorId] = TensorSlots();
2057 }
2058 m_TensorConnections[tensorId].inputSlots.push_back(slot);
2059 }
2060}
2061
Kevin Mayef33cb12021-01-29 14:24:57 +00002062void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
telsoa01c577f2c2018-08-31 09:22:23 +01002063{
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01002064 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01002065 if (tensorIds.size() != layer->GetNumOutputSlots())
2066 {
2067 throw ParseException(
James Ward58dec6b2020-09-11 17:32:44 +01002068 fmt::format("The number of tensor outputs ({}) does not match the number expected ({}) {} ",
2069 tensorIds.size(),
2070 layer->GetNumOutputSlots(),
2071 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002072 }
2073
2074 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
2075 {
2076 std::string tensorId = tensorIds[slotIndex];
2077 armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));
2078
2079 auto it = m_TensorConnections.find(tensorId);
2080
2081 if (it == m_TensorConnections.end())
2082 {
2083 //First time seing this tensor, we need to map it
2084 m_TensorConnections[tensorId] = TensorSlots();
2085 }
2086
Ryan OShea337c17f2020-02-21 12:33:17 +00002087 TensorSlots& tensorSlots = m_TensorConnections[tensorId];
telsoa01c577f2c2018-08-31 09:22:23 +01002088
2089 // assuming there is only one producer for that tensor
2090 if (tensorSlots.outputSlot != nullptr)
2091 {
James Ward58dec6b2020-09-11 17:32:44 +01002092 throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
2093 "tensor:{} {}",
2094 tensorId,
2095 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002096 }
2097 tensorSlots.outputSlot = slot;
2098 }
2099}
2100
Kevin Mayef33cb12021-01-29 14:24:57 +00002101BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
telsoa01c577f2c2018-08-31 09:22:23 +01002102{
2103 for(int i = 0; i < m_Graph->input_size(); ++i)
2104 {
2105 auto input = m_Graph->input(i);
2106 if(input.name() == name)
2107 {
2108 return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
2109 }
2110 }
James Ward58dec6b2020-09-11 17:32:44 +01002111 throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
2112 name, CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002113}
2114
Kevin Mayef33cb12021-01-29 14:24:57 +00002115BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const
telsoa01c577f2c2018-08-31 09:22:23 +01002116{
2117 for(int i = 0; i < m_Graph->output_size(); ++i)
2118 {
2119 auto output = m_Graph->output(i);
2120 if(output.name() == name)
2121 {
2122 return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
2123 }
2124 }
James Ward58dec6b2020-09-11 17:32:44 +01002125 throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
2126 name, CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002127}
2128
Kevin Mayef33cb12021-01-29 14:24:57 +00002129std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model)
telsoa01c577f2c2018-08-31 09:22:23 +01002130{
2131 if(model == nullptr) {
James Ward58dec6b2020-09-11 17:32:44 +01002132 throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2133 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002134 }
2135
2136 std::vector<std::string> inputNames;
2137 std::map<std::string, bool> isConstant;
2138 for(auto tensor : model->graph().initializer())
2139 {
2140 isConstant[tensor.name()] = true;
2141 }
2142 for(auto input : model->graph().input())
2143 {
2144 auto it = isConstant.find(input.name());
2145 if(it == isConstant.end())
2146 {
2147 inputNames.push_back(input.name());
2148 }
2149 }
2150 return inputNames;
2151}
2152
Kevin Mayef33cb12021-01-29 14:24:57 +00002153std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model)
telsoa01c577f2c2018-08-31 09:22:23 +01002154{
2155 if(model == nullptr) {
James Ward58dec6b2020-09-11 17:32:44 +01002156 throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
2157 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01002158 }
2159
2160 std::vector<std::string> outputNames;
2161 for(auto output : model->graph().output())
2162 {
2163 outputNames.push_back(output.name());
2164 }
2165 return outputNames;
2166}
2167
Matthew Sloyanac001ee2021-02-03 10:43:04 +00002168const std::string OnnxParserImpl::GetVersion()
2169{
2170 return ONNX_PARSER_VERSION;
2171}
2172
telsoa01c577f2c2018-08-31 09:22:23 +01002173} // namespace armnnOnnxParser