//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "OnnxParser.hpp"

#include "armnnOnnxParser/Version.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <VerificationHelpers.hpp>

#include <fmt/format.h>

#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>

#include <iostream>
#include <numeric>
#include <armnnUtils/Permute.hpp>

using namespace armnn;

namespace armnnOnnxParser
{

IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {}

IOnnxParser::~IOnnxParser() = default;

IOnnxParser* IOnnxParser::CreateRaw()
{
    return new IOnnxParser();
}

IOnnxParserPtr IOnnxParser::Create()
{
    return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy);
}

void IOnnxParser::Destroy(IOnnxParser* parser)
{
    delete parser;
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText);
}

BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
}

BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkOutputBindingInfo(name);
}

namespace
{
void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes,
                        const onnx::TensorProto::DataType actualValue,
                        const char* validExpr,
                        std::string nodeName,
                        std::string tensorName,
                        const armnn::CheckLocation& location)
{
    bool isValid = std::any_of(validInputTypes.begin(),
                               validInputTypes.end(),
                               [&actualValue](onnx::TensorProto::DataType x) { return x == actualValue; } );
    if (!isValid)
    {
        throw ParseException(
            fmt::format("Datatype {} is not valid for tensor '{}' of node '{}', not in {{{}}}. {}",
                        onnx::TensorProto::DataType_Name(actualValue),
                        tensorName,
                        nodeName,
                        validExpr,
                        location.AsString()));
    }
}

#define CHECK_VALID_DATATYPE(NODE, TENSOR, ACTUAL, ...) \
CheckValidDataType({__VA_ARGS__}, ACTUAL, #__VA_ARGS__, NODE, TENSOR, CHECK_LOCATION())

using StrTypeListPair = std::pair<const char*, std::initializer_list<onnx::TensorProto::DataType>>;
#define STR_LIST(...) StrTypeListPair(#__VA_ARGS__, {__VA_ARGS__})
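
// Illustrative use of the two macros above (the same patterns appear throughout this file):
//   CHECK_VALID_DATATYPE(node.name(), tensorName, actualDataType, onnx::TensorProto::FLOAT);
// expands to a CheckValidDataType() call that throws a ParseException when actualDataType
// is not in the accepted list, while STR_LIST(onnx::TensorProto::FLOAT) packs an accepted-type
// list together with its stringified form for the VALID_INPUTS macro defined further down.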

template <typename Callable>
void ReadMandatoryNodeAttributeImpl(const onnx::NodeProto& node,
                                    const std::string& attribName,
                                    onnx::AttributeProto::AttributeType expectedType,
                                    Callable callable)
{
    auto attribs = node.attribute();
    int attriNum = 0;
    while (attriNum < node.attribute_size())
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(fmt::format("Attribute {} of node {} expected to have {} as "
                                                 "onnx::AttributeProto::AttributeType, but found {} instead {}",
                                                 attribName,
                                                 node.name(),
                                                 onnx::AttributeProto::AttributeType_Name(expectedType),
                                                 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                                 CHECK_LOCATION().AsString()));
            }
            break;
        }
        ++attriNum;
    }
    if (attriNum == node.attribute_size())
    {
        throw ParseException(fmt::format("Could not find required attribute {} in node {} {}",
                                         attribName, node.name(), CHECK_LOCATION().AsString()));
    }
}

template <typename Callable>
void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
                                   const std::string& attribName,
                                   onnx::AttributeProto::AttributeType expectedType,
                                   Callable callable)
{
    auto attribs = node.attribute();
    for (int attriNum = 0; attriNum < node.attribute_size(); ++attriNum)
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(
                    fmt::format("Attribute {} of node {} expected to have {} as onnx::AttributeProto::AttributeType, "
                                "but found {} instead {}",
                                attribName,
                                node.name(),
                                onnx::AttributeProto::AttributeType_Name(expectedType),
                                onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
}

int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
                                       const std::string& name,
                                       const int64_t defaultValue = 0)
{
    int64_t attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.i();
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                           const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                   [&attriList](const onnx::AttributeProto& attrValue)
                                   {
                                       for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                       {
                                           attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                       }
                                   });
    return attriList;
}

uint32_t ReadOptionalNodeUint32Attribute(const onnx::NodeProto& node,
                                         const std::string& name,
                                         const uint32_t defaultVal = 0u)
{
    uint32_t attribValue = defaultVal;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = CHECKED_NON_NEGATIVE(CHECKED_INT32((attrValue.i())));
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                          const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                  [&attriList](const onnx::AttributeProto& attrValue)
                                  {
                                      for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                      {
                                          attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                      }
                                  });

    return attriList;
}

float ReadOptionalNodeFloatAttribute(const onnx::NodeProto& node,
                                     const std::string& name,
                                     const float defaultValue = 0.0f)
{
    float attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::FLOAT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.f();
                                  });
    return attribValue;
}

std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const std::string& name)
{
    std::string attribValue = "";
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::STRING,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.s();
                                  });
    return attribValue;
}
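
// Example (illustrative): for a MaxPool node carrying the ONNX attribute
// kernel_shape = [2, 2], ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape")
// returns {2, 2}. The optional readers fall back to their default value (or an
// empty list) when the attribute is absent instead of throwing.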

armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
    DataType type;
    switch(data_type)
    {
        case onnx::TensorProto::FLOAT:
        {
            type = DataType::Float32;
            break;
        }
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64:
        {
            type = DataType::Signed32;
            break;
        }
        default:
        {
            throw ParseException(
                fmt::format("'{}' is not a currently supported datatype for tensor {}."
                            " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                            onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
                            name,
                            CHECK_LOCATION().AsString() ));
        }
    }

    // Return a scalar TensorInfo to avoid crashes on trivial (zero-dimensional) tensors
    if (shape.empty())
    {
        return TensorInfo(TensorShape(), type);
    }

    return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}

armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
{
    const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
    std::vector<unsigned int> shapeDims;
    for (int i = 0; i < onnxShape.dim_size(); ++i)
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
    }

    if (shapeDims.empty())
    {
        shapeDims.push_back(1);
    }

    return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}

armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
{
    std::vector<unsigned int> shapeDims;

    for (auto dim: tensor.dims())
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
    }

    if (shapeDims.empty())
    {
        shapeDims.push_back(1);
    }

    return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}
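
// Note: both ToTensorInfo overloads above promote rank-0 (scalar) ONNX tensors to
// shape [1], and INT64 element types are mapped to armnn::DataType::Signed32, so
// 64-bit values are later narrowed with a range check (see CreateInt64ConstTensor).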

std::string TensorInfoAsString(const TensorInfo& info,
                               const std::string& name,
                               const onnx::TensorProto::DataType& type)
{
    const TensorShape shape = info.GetShape();
    std::stringstream ss;
    ss << "tensor '" << name << "' contains "
       << onnx::TensorProto::DataType_Name(type)
       << " and has shape [";

    for (uint32_t i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        ss << shape[i] << ", ";
    }
    ss << shape[shape.GetNumDimensions() - 1] << "]";
    return ss.str();
}
337
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000338void CalcPadding(uint32_t inputSize,
339 uint32_t filterSize,
340 uint32_t stride,
341 uint32_t dilation,
342 uint32_t* paddingFront,
343 uint32_t* paddingBack,
344 bool isUpper)
telsoa01c577f2c2018-08-31 09:22:23 +0100345{
346 uint32_t outputSize = (inputSize + stride - 1) / stride;
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000347 uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
348 uint32_t temp = (outputSize - 1) * stride + dilatedSize;
telsoa01c577f2c2018-08-31 09:22:23 +0100349 *paddingFront = (temp - inputSize) / 2;
350 *paddingBack = *paddingFront;
351 if((temp - inputSize) % 2 == 1)
352 {
353 if (isUpper)
354 {
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000355 *paddingBack += 1;
telsoa01c577f2c2018-08-31 09:22:23 +0100356 }
357 else
358 {
Sadik Armagan60bb9d82021-01-11 15:15:01 +0000359 *paddingFront += 1;
telsoa01c577f2c2018-08-31 09:22:23 +0100360 }
361 }
362}
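
// Worked example of the SAME padding arithmetic above: inputSize=4, filterSize=3,
// stride=2, dilation=1 gives outputSize=(4+2-1)/2=2, dilatedSize=3 and
// temp=(2-1)*2+3=5, so one padding element is needed in total; SAME_UPPER
// (isUpper=true) places the odd element at the back, SAME_LOWER at the front.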

TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                              const TensorShape& inShape,
                              const std::string& outName)
{
    std::vector<int> targetDims;
    for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
    {
        int val = CHECKED_INT32(targetShapeTensor[i]);
        if(val == 0)
        {
            targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
        }
        else
        {
            targetDims.push_back(val);
        }
    }

    std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
    const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
    if (stretchDim != targetDims.end())
    {
        if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
        {
            std::stringstream ss;
            ss << "[ ";
            for(uint i = 0; i < targetDims.size() - 1; ++i)
            {
                ss << targetDims[i] << ", ";
            }
            ss << targetDims[targetDims.size() - 1] << " ]";

            throw ParseException(
                fmt::format("Error during creation of reshaped tensor '{}'. At most one component of shape can be "
                            " -1 and here, shape is {} {}",
                            outName,
                            ss.str(),
                            CHECK_LOCATION().AsString()));
        }

        auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
                                                                                   -1, std::multiplies<int32_t>()));
        auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
        outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
    }
    TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
    return TensorInfo(outShape, DataType::Float32);
}
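
// Worked example of the ONNX Reshape semantics above: with inShape [2, 3, 4]
// (24 elements) and target shape [0, -1], the 0 copies the input dimension (2)
// and the -1 entry is stretched; accumulate() starts at -1 so the product of the
// remaining dimensions comes out positive (2), giving outDims[1] = 24 / 2 = 12
// and a final shape of [2, 12].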

} //namespace

const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = {
    { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization},
    { "GlobalAveragePool",  &OnnxParserImpl::ParseGlobalAveragePool},
    { "AveragePool",        &OnnxParserImpl::ParseAveragePool },
    { "Clip",               &OnnxParserImpl::ParseClip },
    { "Constant",           &OnnxParserImpl::ParseConstant },
    { "MaxPool",            &OnnxParserImpl::ParseMaxPool },
    { "Reshape",            &OnnxParserImpl::ParseReshape },
    { "Sigmoid",            &OnnxParserImpl::ParseSigmoid },
    { "Tanh",               &OnnxParserImpl::ParseTanh },
    { "Relu",               &OnnxParserImpl::ParseRelu },
    { "LeakyRelu",          &OnnxParserImpl::ParseLeakyRelu },
    { "Conv",               &OnnxParserImpl::ParseConv },
    { "Add",                &OnnxParserImpl::ParseAdd },
    { "Flatten",            &OnnxParserImpl::ParseFlatten },
    { "Shape",              &OnnxParserImpl::ParseShape },
    { "Gather",             &OnnxParserImpl::ParseGather },
};
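
// LoadGraph() dispatches through this table: it looks up each node's op_type()
// and invokes the matching member-function pointer. MatMul and Add nodes are a
// special case, since DetectFullyConnected() may fuse a MatMul/Add pair into a
// single FullyConnected layer before this table is consulted.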

template<typename TypePair, typename Location>
void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
                                    TypePair validInputs,
                                    const Location& location)
{
    for(auto input : node.input())
    {
        CheckValidDataType(validInputs.second,
                           m_TensorsInfo[input].m_dtype,
                           validInputs.first,
                           node.name(),
                           input,
                           location);
    }
}

#define VALID_INPUTS(NODE, VALID_INPUTS) \
    OnnxParserImpl::ValidateInputs(NODE, \
                                   VALID_INPUTS, \
                                   CHECK_LOCATION())
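
// Typical call site, as used by the Parse* functions below:
//   VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
// checks every input tensor of the node against the accepted datatype list and
// throws a ParseException naming the offending tensor if one does not match.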

std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
                                                          const IConnectableLayer* layer,
                                                          std::vector<TensorShape> inputShapes)
{
    ARMNN_ASSERT(! outNames.empty());
    bool needCompute = std::any_of(outNames.begin(),
                                   outNames.end(),
                                   [this](std::string name)
                                   {
                                       return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
                                   });
    std::vector<TensorInfo> outInfo;
    //if the output info(s) are not here, we need to compute them
    std::vector<TensorShape> inferredShapes;
    if(needCompute)
    {
        inferredShapes = layer->InferOutputShapes(inputShapes);
        ARMNN_ASSERT(inferredShapes.size() == outNames.size());
    }
    for (uint i = 0; i < outNames.size(); ++i)
    {
        if(needCompute)
        {
            m_TensorsInfo[outNames[i]] = OnnxTensor();
            m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
                TensorInfo(inferredShapes[i], DataType::Float32));
        }
        outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
    }
    return outInfo;
}
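
// When the ONNX model carries no value_info for an output, the shape is inferred
// from the layer itself via InferOutputShapes() and the element type is assumed
// to be Float32; otherwise the TensorInfo recorded during SetupInfo() is reused.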

OnnxParserImpl::OnnxParserImpl()
    : m_Network(nullptr, nullptr)
{
}

void OnnxParserImpl::ResetParser()
{
    m_Network = armnn::INetworkPtr(nullptr, nullptr);
    m_Graph = nullptr;
}

void OnnxParserImpl::Cleanup()
{
    m_TensorConnections.clear();
    m_TensorsInfo.clear();
    m_OutputsMap.clear();
    m_OutputsFusedAndUsed.clear();
}

template<typename T>
std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
CreateConstTensorImpl(const T* bufferPtr,
                      armnn::TensorInfo& tensorInfo,
                      const armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());

    std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);

    if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
    {
        tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
        armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
                            reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
    }
    else
    {
        ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
    }

    return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
}
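
// When a permutation vector is supplied, both the TensorInfo and the backing data
// are reordered. For example, AddConvLayerWithDepthwiseConv() below passes
// PermutationVector{3,0,1,2} so that an ONNX [O,1,H,W] weight tensor comes back
// laid out as [1,H,W,O].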

std::pair<ConstTensor, std::unique_ptr<float[]>>
OnnxParserImpl::CreateConstTensor(const std::string name,
                                  armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();

    // Const tensors require at least a list of values
    if (tensorInfo.GetNumElements() == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    auto srcData = onnxTensor.float_data().data();
    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.float_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }
        return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
    }
    else
    {
        return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
                                            tensorInfo,
                                            permutationVector);
    }
}

std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
                                       armnn::Optional<armnn::PermutationVector&> permutationVector)
{
    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);

    // Makes sure IsConstant flag is set.
    tensorInfo.SetConstant();
    uint numElements = tensorInfo.GetNumElements();

    // Const tensors require at least a list of values
    if (numElements == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }

    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        auto srcData = onnxTensor.int64_data().data();
        if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.int64_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }

        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }

        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
    else
    {
        auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
        std::vector<int32_t> int32Data;
        for(uint i = 0; i < numElements; i++)
        {
            int32_t int32Value = CHECKED_INT32(srcData[i]);
            int32Data.push_back(int32Value);
        }
        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
    }
}
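
// ArmNN has no 64-bit integer tensor type here, so INT64 initializers are narrowed
// element-by-element to int32_t; CHECKED_INT32 throws if a value falls outside the
// int32 range rather than silently truncating it.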

ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "r");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    using google::protobuf::io::FileInputStream;
    std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
    bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}


ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "rb");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::FileInputStream inStream(fileno(fd));
    google::protobuf::io::CodedInputStream codedStream(&inStream);
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
{
    if (protoText == "")
    {
        throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
                                                   CHECK_LOCATION().AsString()));
    }
    // Parse the string into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
{
    m_Network = INetwork::Create();
    try
    {
        m_Graph = std::make_unique<onnx::GraphProto>(*model.mutable_graph());
        LoadGraph();
    }
    catch (const ParseException& e)
    {
        Cleanup();
        throw e;
    }
    Cleanup();
    return std::move(m_Network);
}

void OnnxParserImpl::LoadGraph()
{
    ARMNN_ASSERT(m_Graph.get() != nullptr);

    // Fill m_TensorsInfo with the shapes and values of every tensor
    SetupInfo(m_Graph->mutable_output());
    SetupInfo(m_Graph->mutable_input());
    SetupInfo(m_Graph->mutable_value_info());

    for (auto tensor : m_Graph->initializer())
    {
        m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.data_type());
    }

    SetupInputLayers();
    SetupOutputLayers();

    // Detect FullyConnected layers with bias and update the FusedAndUsed map accordingly
    DetectFullyConnected();

    // Parse the graph
    for(size_t nodeIndex = 0; nodeIndex < static_cast<size_t>(m_Graph->node_size()); nodeIndex++)
    {
        auto node = m_Graph->node(static_cast<int>(nodeIndex));
        const std::string& operation = node.op_type();

        // check which layers we handled already (add and matmul fused as FC)
        if (operation == "MatMul" )
        {
            if(m_OutputsFusedAndUsed[nodeIndex].inputForNodes != m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.size())
            {
                // Node which cannot be fused as a FullyConnected layer (used in layers as a simple matmul output)
                AddFullyConnected(node);
            }
        }
        else if (!(m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) && operation == "Add")
        {
            int matmulIndex = static_cast<int> (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes[0]);
            AddFullyConnected(m_Graph->node(matmulIndex), &node);
        }
        else if (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) //node is not part of a fused layer
        {
            auto it = m_ParserFunctions.find(operation);
            if (it != m_ParserFunctions.end())
            {
                auto func = it->second;
                (this->*func)(node);
            }
            else
            {
                throw ParseException(fmt::format("Unsupported operation {} for node '{}' {}",
                                                 operation,
                                                 node.name(),
                                                 CHECK_LOCATION().AsString()));
            }
        }
    }

    // Make the connections between the outputs and inputs of each layer
    for (const auto& tensorCon : m_TensorConnections)
    {
        if (tensorCon.second.outputSlot != nullptr)
        {
            for (size_t inputSlotIdx = 0; inputSlotIdx < tensorCon.second.inputSlots.size(); ++inputSlotIdx)
            {
                tensorCon.second.outputSlot->Connect(*(tensorCon.second.inputSlots[inputSlotIdx]));
            }
        }
    }
}

void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
{
    for (auto tensor : *list)
    {
        m_TensorsInfo[tensor.name()] = OnnxTensor();
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.type().tensor_type().elem_type());
    }
}

void OnnxParserImpl::DetectFullyConnected()
{
    m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary());
    auto matmulAndConstant = [&](const std::string& constInput,
                                 const std::string& matmulInput,
                                 int& nodeIndex)
    {
        auto matmulIt = m_OutputsMap.find(matmulInput);
        if(matmulIt != m_OutputsMap.end() && matmulIt->second.first->op_type() == "MatMul"
            && m_TensorsInfo[constInput].isConstant())
        {
            nodeIndex = matmulIt->second.second;
            return true;
        }
        return false;
    };

    for(int nodeIndex = 0; nodeIndex < m_Graph->node_size(); nodeIndex++)
    {
        const onnx::NodeProto* node = &m_Graph->node(nodeIndex);
        for (const std::string& output : node->output())
        {
            m_OutputsMap[output] = std::make_pair(node, nodeIndex);
        }

        for (const std::string& input : node->input()) //count how many times a node is used as input
        {
            auto matmulIt = m_OutputsMap.find(input);
            if(matmulIt != m_OutputsMap.end()){
                ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes; //node used
            }
        }

        if (node->op_type() == "Add")
        {
            int matmulIndex = 0;
            if (matmulAndConstant(node->input(0), node->input(1), matmulIndex) ||
                matmulAndConstant(node->input(1), node->input(0), matmulIndex))
            {
                //matmul and add were fused
                m_OutputsFusedAndUsed[static_cast<size_t>(matmulIndex)].fusedWithNodes
                                                                      .push_back(static_cast<size_t>(nodeIndex));

                m_OutputsFusedAndUsed[static_cast<size_t>(nodeIndex)].fusedWithNodes
                                                                     .push_back(static_cast<size_t>(matmulIndex));
            }
        }
    }

    for (auto output: m_Graph->output()) { //tensors that are also graph outputs count as usages too
        auto matmulIt = m_OutputsMap.find(output.name());
        if(matmulIt != m_OutputsMap.end()){
            ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes;
        }
    }
}
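
// Example of the pattern this detects: Add(MatMul(X, W), B) where W and B are
// initializers. The MatMul and Add nodes are marked as fused with each other so
// that LoadGraph() emits a single FullyConnected layer (weights W, bias B)
// instead of two separate layers, provided the MatMul output feeds nothing else.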

template<typename Location>
void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node,
                                      std::string* inputName,
                                      std::string* constName,
                                      const Location& location)
{
    int cstIndex;
    if (m_TensorsInfo[node.input(0)].isConstant())
    {
        cstIndex = 0;
    }
    else if (m_TensorsInfo[node.input(1)].isConstant())
    {
        cstIndex = 1;
    }
    else
    {
        throw ParseException(fmt::format("One of the input tensors ('{}' or '{}') should be constant in node '{}' {}",
                                         node.input(0),
                                         node.input(1),
                                         node.name(),
                                         location.AsString()));
    }
    if(constName)
    {
        *constName = node.input(cstIndex);
    }
    if(inputName)
    {
        *inputName = node.input(!cstIndex);
    }
}

template<typename Location>
void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location)
{
    TensorShape shape = m_TensorsInfo[name].m_info->GetShape();
    std::vector<uint32_t> newShape;
    for(uint i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        if(shape[i] != 1)
        {
            throw ParseException(
                fmt::format("Only tensors with shape [1, ..., 1, X] can be converted to 1D and {} {}",
                            TensorInfoAsString(*m_TensorsInfo[name].m_info, name, m_TensorsInfo[name].m_dtype),
                            location.AsString()));
        }
    }
    newShape.push_back(shape[shape.GetNumDimensions() - 1]);

    m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
}
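
// Example: a bias initializer stored as [1, 1, 256] is squeezed to [256] here;
// a shape such as [2, 256] is rejected because every leading dimension must be 1.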

void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
{
    ARMNN_ASSERT(node.op_type() == "Conv");

    DepthwiseConvolution2dDescriptor desc;
    desc.m_PadLeft = convDesc.m_PadLeft;
    desc.m_PadRight = convDesc.m_PadRight;
    desc.m_PadTop = convDesc.m_PadTop;
    desc.m_PadBottom = convDesc.m_PadBottom;
    desc.m_StrideX = convDesc.m_StrideX;
    desc.m_StrideY = convDesc.m_StrideY;
    desc.m_BiasEnabled = convDesc.m_BiasEnabled;

    armnn::IConnectableLayer* layer;

    // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNN's depthwise weights layout [1,H,W,O]
    armnn::PermutationVector perVec {3,0,1,2};
    auto weightTensor = CreateConstTensor(node.input(1), perVec);

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
                                                          weightTensor.first,
                                                          Optional<ConstTensor>(biasTensor.first),
                                                          node.name().c_str());
    }
    else
    {
        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
                                                          weightTensor.first,
                                                          EmptyOptional(),
                                                          node.name().c_str());
    }
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          weightTensor.first.GetInfo().GetShape() });

    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
{
    // find matmul inputs
    std::string weightName;
    std::string inputName;
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
    VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));

    GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());

    FullyConnectedDescriptor desc;
    desc.m_BiasEnabled = addNode != nullptr;

    IConnectableLayer* layer = nullptr;
    if(desc.m_BiasEnabled)
    {
        // find bias const
        std::string biasName;
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
        VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));

        GetInputAndParam(*addNode, nullptr, &biasName, CHECK_LOCATION());

        // Output shape is [1, weights[1]] and a 1D vector in ONNX can be [1,X], so we convert biases to ArmNN's 1D layout
        To1DTensor(biasName, CHECK_LOCATION());
        TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
        TensorInfo biasInfo = *m_TensorsInfo[biasName].m_info;

        if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
        {
            throw ParseException(
                fmt::format("Shape of weights '{}' and bias of following Add node '{}' do not match : {}"
                            " and {} ( /!\\ bias should be a 1D tensor) {}",
                            weightName,
                            addNode->name(),
                            TensorInfoAsString(*m_TensorsInfo[weightName].m_info, weightName,
                                               m_TensorsInfo[weightName].m_dtype),
                            TensorInfoAsString(*m_TensorsInfo[biasName].m_info, biasName,
                                               m_TensorsInfo[biasName].m_dtype ),
                            CHECK_LOCATION().AsString()));
        }

        // Just add a FullyConnected layer, weights and biases are handled as inputs now.
        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({addNode->output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});
        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // Add constant layers to store weights/biases and connect them to the FullyConnected layer.
        if(m_TensorsInfo[weightName].isConstant())
        {
            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);

            weightInfo.SetConstant();
            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
        }

        if(m_TensorsInfo[biasName].isConstant())
        {
            IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);

            biasInfo.SetConstant();
            biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
        }

        RegisterInputSlots(layer, {inputName, weightName, biasName});
        RegisterOutputSlots(layer, {addNode->output(0)});
    }
    else
    {
        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({matmulNode.output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});
        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        // Add constant layer to store weights and connect to FullyConnected layer.
        if(m_TensorsInfo[weightName].isConstant())
        {
            TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);

            weightInfo.SetConstant();
            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
        }

        RegisterInputSlots(layer, {inputName, weightName});
        RegisterOutputSlots(layer, {matmulNode.output(0)});
    }
}

void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    std::vector<uint32_t> kernel_shape = ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape"); // size of pooling window
    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");

    desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
    desc.m_PoolWidth = kernel_shape[1];
    desc.m_PoolHeight = kernel_shape[0];

    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    // Check the new padding version first
    if(pads.empty())
    {
        // Check the deprecated version
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if( paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if (paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(fmt::format("Invalid auto_pad attribute for node {}. "
                                                 "Only SAME_UPPER, SAME_LOWER or VALID supported and found {} {}",
                                                 node.name(),
                                                 paddingString,
                                                 CHECK_LOCATION().AsString()));
            }
            auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth = inputInfo.GetShape()[3];
            CalcPadding(inputHeight,
                        desc.m_PoolHeight,
                        desc.m_StrideY,
                        1u,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        desc.m_PoolWidth,
                        desc.m_StrideX,
                        1u,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop = pads[0];
        desc.m_PadLeft = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight = pads[3];
    }

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0,
                                                                        const std::string& input1)
{
    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input1);
        PrependForBroadcast(outputName, input1, input0);
        inputs.second = outputName;
    }
    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input0);
        PrependForBroadcast(outputName, input0, input1);
        inputs.first = outputName;
    }
    return inputs;
}
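
// Illustrative case: adding inputs of shape [1,3,4] and [4], the lower-rank input
// is routed through a reshape (via PrependForBroadcast, defined later in this
// file, which presumably prepends leading 1-sized dimensions) so both operands
// have equal rank; ParseAdd() then only checks that each dimension pair matches
// or one of the pair is 1.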

void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
{
    auto armnnTensor = CreateInt64ConstTensor(tensorName);
    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
    RegisterOutputSlots(layer, {tensorName});
}

void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName,
                                        const std::string& outputName,
                                        const std::string& layerName)
{
    const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info;
    ReshapeDescriptor reshapeDesc;
    reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();

    IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
    ARMNN_ASSERT(layer != nullptr);
    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {inputName});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {outputName});
}

void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    ActivationDescriptor desc;
    desc.m_Function = func;

    if (func == ActivationFunction::BoundedReLu)
    {
        if (node.input_size() == 1 && node.attribute_size() > 0)
        {
            desc.m_A = ReadOptionalNodeFloatAttribute(node, "max", std::numeric_limits<float>::max());
            desc.m_B = ReadOptionalNodeFloatAttribute(node, "min", std::numeric_limits<float>::lowest());
        }
        else
        {
            desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
            desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
        }
    }

    IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseClip(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::BoundedReLu);
}
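
// Clip is mapped onto a BoundedReLu activation. ParseActivation() above handles
// both encodings of the clamp bounds: pre-opset-11 Clip carries them as the
// min/max attributes, while opset 11 onwards passes them as optional inputs 1
// and 2, falling back to the float lowest/max limits when they are absent.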

void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::Sigmoid);
}

void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::TanH);
}

void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::ReLu);
}

void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node)
{
    ParseActivation(node, ActivationFunction::LeakyReLu);
}

Kevin Mayef33cb12021-01-29 14:24:57 +00001315void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001316{
Ryan OSheaed27ee72020-04-22 16:37:29 +01001317 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1318 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
telsoa01c577f2c2018-08-31 09:22:23 +01001319
Ryan OSheaed27ee72020-04-22 16:37:29 +01001320 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
telsoa01c577f2c2018-08-31 09:22:23 +01001321
Ryan OSheaed27ee72020-04-22 16:37:29 +01001322 // TODO: unify broadcast validation code across layers
1323 // tracked by: IVGCVSW-1576
telsoa01c577f2c2018-08-31 09:22:23 +01001324
Ryan OSheaed27ee72020-04-22 16:37:29 +01001325 // Checking broadcast compatibility : only scalar or 1D tensors
1326 auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
1327 auto input0 = *m_TensorsInfo[inputs.first].m_info;
1328 auto input1 = *m_TensorsInfo[inputs.second].m_info;
1329 ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
1330
1331 unsigned int numDims = input0.GetNumDimensions();
1332 for (unsigned int i = 0; i < numDims; i++)
telsoa01c577f2c2018-08-31 09:22:23 +01001333 {
Ryan OSheaed27ee72020-04-22 16:37:29 +01001334 unsigned int dim0 = input0.GetShape()[i];
1335 unsigned int dim1 = input1.GetShape()[i];
1336 if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
telsoa01c577f2c2018-08-31 09:22:23 +01001337 {
James Ward58dec6b2020-09-11 17:32:44 +01001338 throw ParseException(
1339 fmt::format("Broadcast is only supported for scalar or 1D tensors in Add node '{}'. "
1340 "Input dimensions should either match or one should be of size 1 and here, "
1341 "{} and {} {}",
1342 node.name(),
1343 TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
1344 m_TensorsInfo[inputs.first].m_dtype),
1345 TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
1346 m_TensorsInfo[inputs.second].m_dtype),
1347 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001348 }
telsoa01c577f2c2018-08-31 09:22:23 +01001349 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001350
1351
    IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[inputs.first].m_info->GetShape(),
                                          m_TensorsInfo[inputs.second].m_info->GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connections -> for constant inputs, we need to create a constant layer first
    if(m_TensorsInfo[inputs.first].isConstant()) {
        CreateConstantLayer(inputs.first, fmt::format("Add:constant_of_{}", node.input(0)));
    }
    if(m_TensorsInfo[inputs.second].isConstant()) {
        CreateConstantLayer(inputs.second, fmt::format("Add:constant_of_{}", node.input(1)));
    }
    RegisterInputSlots(layer, {inputs.first, inputs.second});

    // register the output connection
    RegisterOutputSlots(layer, {node.output(0)});
}

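// Maps the ONNX AveragePool operator onto an ArmNN Pooling2d layer. The optional
// count_include_pad attribute (default 0) decides whether padded values are counted
// in the averaging divisor; when set, PaddingMethod::IgnoreValue is used so the
// padding fields are counted.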
void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc;
    desc.m_PoolType = PoolingAlgorithm::Average;

    uint32_t count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
    if(count_include_pad) {
        desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
    }
    AddPoolingLayer(node, desc);
}

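// Maps the ONNX BatchNormalization operator (inference form) onto an ArmNN
// BatchNormalization layer, i.e. y = scale * (x - mean) / sqrt(var + epsilon) + bias.
// scale, bias, mean and var must be constant inputs.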
void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
{
    // Ignore the momentum and spatial parameters

    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
    for(int ind = 1; ind < node.input_size(); ++ind)
    {
        auto tensor = node.input(ind);
        if(! m_TensorsInfo[tensor].isConstant())
        {
            throw ParseException(
                fmt::format("Input tensor '{}' should be constant in BatchNormalization node '{}' {}",
                            tensor,
                            node.name(),
                            CHECK_LOCATION().AsString()));
        }
    }

    float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
    BatchNormalizationDescriptor desc;
    desc.m_Eps = epsilon;

    auto scaleTensor = CreateConstTensor(node.input(1));
    auto biasTensor = CreateConstTensor(node.input(2));
    auto meanTensor = CreateConstTensor(node.input(3));
    auto varTensor = CreateConstTensor(node.input(4));

    IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
                                                                     meanTensor.first,
                                                                     varTensor.first,
                                                                     biasTensor.first,
                                                                     scaleTensor.first,
                                                                     node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    RegisterInputSlots(layer, {node.input(0)}); // don't register constant inputs

    // register the output connection
    RegisterOutputSlots(layer, {node.output(0)});
}

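// Materialises an ONNX Constant node. The tensor payload is cached in m_TensorsInfo
// so later layers can consume it as a constant parameter; only FLOAT and INT64
// constants are currently turned into layers.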
void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
    if (!node.attribute(0).has_t())
    {
        throw ParseException(fmt::format("Value not found for Constant node '{}' {}",
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }
    const onnx::TensorProto& onnxTensor = node.attribute(0).t();

    // Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
    m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
    m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());

    if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
    {
        CreateConstantLayer(node.output(0), node.name());
    }
    else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
    {
        CreateInt64ConstantLayer(node.output(0), node.name());
    }
    else
    {
        throw ParseException(fmt::format("Data type not supported for Constant node '{}' {}",
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }
}

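// Maps the ONNX Conv operator onto an ArmNN Convolution2d layer. Input and weights
// are expected in NCHW/OIHW layout and the weights (and optional bias) must be
// constant. group == 1 is a plain convolution, group == input channels becomes a
// DepthwiseConvolution2d; other group values are not supported.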
void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); // input, weight, (bias)
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    if(m_TensorsInfo[node.input(0)].m_info->GetNumDimensions() != 4)
    {
        throw ParseException(
            fmt::format("ArmNN only supports 2D convolution, but Conv layer '{}' has input {} {}",
                        node.name(),
                        TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
                                           m_TensorsInfo[node.input(0)].m_dtype),
                        CHECK_LOCATION().AsString()));
    }

    if(!m_TensorsInfo[node.input(1)].isConstant())
    {
        throw ParseException(
            fmt::format("Weights '{}' should be constant in Conv layer '{}' {}",
                        node.input(1),
                        node.name(),
                        CHECK_LOCATION().AsString()));
    }

    auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;

    Convolution2dDescriptor desc;
    desc.m_BiasEnabled = false;

    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
    if(!dilations.empty())
    {
        desc.m_DilationX = dilations[1];
        desc.m_DilationY = dilations[0];
    }

    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
    // Check the new padding specification first
    if(pads.empty())
    {
        // Fall back to the deprecated auto_pad attribute
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if( paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if (paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(
                    fmt::format("Invalid auto_pad attribute for node {}. Only SAME_UPPER, SAME_LOWER or VALID "
                                "are supported, but found {} {}",
                                node.name(),
                                paddingString,
                                CHECK_LOCATION().AsString()));
            }
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth = inputInfo.GetShape()[3];

            uint32_t weightHeight;
            uint32_t weightWidth;
            std::vector<uint32_t> kernel_shape = ReadOptionalNodeUint32ListAttribute(node, "kernel_shape");
            if (kernel_shape.empty())
            {
                const TensorInfo weightTensorInfo = *m_TensorsInfo[node.input(1)].m_info;
                weightHeight = weightTensorInfo.GetShape()[2];
                weightWidth = weightTensorInfo.GetShape()[3];
            }
            else
            {
                weightHeight = kernel_shape[0];
                weightWidth = kernel_shape[1];
            }
            CalcPadding(inputHeight,
                        weightHeight,
                        desc.m_StrideY,
                        desc.m_DilationY,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        weightWidth,
                        desc.m_StrideX,
                        desc.m_DilationX,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
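        // ONNX 'pads' for a 2D NCHW convolution is ordered [top, left, bottom, right],
        // i.e. [x1_begin, x2_begin, x1_end, x2_end] over the spatial dims H, W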
        desc.m_PadTop = pads[0];
        desc.m_PadLeft = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight = pads[3];
    }

    uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
    if(group > 1)
    {
        if (group > inputInfo.GetShape()[1])
        {
            throw ParseException(
                fmt::format("Error parsing Convolution node: {}. "
                            "The 'group'={} parameter cannot be larger than the "
                            "channel of the input shape={} (in NCHW format). {}",
                            node.name(),
                            group,
                            inputInfo.GetShape()[1],
                            CHECK_LOCATION().AsString()));
        }
        else if (group == inputInfo.GetShape()[1])
        {
            // we use a depthwise convolution here, because the number of groups
            // equals the number of input channels
            AddConvLayerWithDepthwiseConv(node, desc);
            return;
        }
        else
        {
            // TODO: split the input by channels into channels/groups separate convolutions
            //  and concatenate the results afterwards
            throw ParseException(fmt::format("Error parsing Convolution node: {}. "
                                             "The 'group'={} parameter should be 1 or be equal to the "
                                             "channel of the input shape={} (in NCHW format). {}",
                                             node.name(),
                                             group,
                                             inputInfo.GetShape()[1],
                                             CHECK_LOCATION().AsString()));
        }
    }

    armnn::IConnectableLayer* layer;
    auto weightTensor = CreateConstTensor(node.input(1));

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        layer = m_Network->AddConvolution2dLayer(desc,
                                                 weightTensor.first,
                                                 Optional<ConstTensor>(biasTensor.first),
                                                 node.name().c_str());
    }
    else
    {
        layer = m_Network->AddConvolution2dLayer(desc,
                                                 weightTensor.first,
                                                 EmptyOptional(),
                                                 node.name().c_str());
    }
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

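// Maps the ONNX Flatten operator onto a reshape. The output is always 2D:
// dimension1 = d_0 * ... * d_(axis-1) and dimension2 = d_axis * ... * d_n,
// e.g. flattening a [2,3,4,5] tensor with axis=2 yields [6,20] (shapes illustrative).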
void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT);

    int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

    /// Negative axis conversion
    if (axis < 0)
    {
        axis += inputShape.GetNumDimensions();
    }

    /// Check axis is within dimensions
    if (axis < 0 || axis >= inputShape.GetNumDimensions())
    {
        throw ParseException(fmt::format("Axis '{}' invalid. Tensor has '{}' dimensions in FlattenLayer '{}'",
                                         axis, inputShape.GetNumDimensions(), node.name()));
    }

    /// If the chosen axis is 0, dimension1 will always be 1 in the output; default dimension2 to 1 because 0 is invalid
    uint dimension1{1};
    uint dimension2{1};
    uint i{0};

    /// dimension1 = (d_0 * d_1 ... d_(axis-1))
    for (i = 0; i < axis; i++){
        dimension1 *= inputShape[i];
    }

    /// dimension2 = (d_axis * d_(axis+1) ... d_n)
    for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++){
        dimension2 *= inputShape[i];
    }

    TensorShape outputShape{dimension1, dimension2};

    auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
    CreateReshapeLayer(node.input(0), node.output(0), node.name());
}

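// Maps the ONNX Gather operator onto an ArmNN Gather layer. The axis attribute
// (default 0) selects the dimension to index; e.g. gathering indices [0,2] along
// axis 0 of a [4,3] tensor yields a [2,3] tensor (shapes illustrative).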
void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    armnn::GatherDescriptor gatherDescriptor;
    gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));

    IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    RegisterInputSlots(layer, { node.input(0), node.input(1) });

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, { node.output(0) });
}

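// Maps the ONNX GlobalAveragePool operator onto an ArmNN Pooling2d layer whose
// kernel covers the whole spatial extent, so an [N,C,H,W] input pools down to [N,C,1,1].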
void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc = Pooling2dDescriptor();
    desc.m_PoolType = PoolingAlgorithm::Average;

    // kernel size is the same as input
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    desc.m_PoolWidth = inputShape[3];
    desc.m_PoolHeight = inputShape[2];

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc;
    desc.m_PoolType = PoolingAlgorithm::Max;
    desc.m_PaddingMethod = PaddingMethod::Exclude;
    AddPoolingLayer(node, desc);
}

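// Maps the ONNX Shape operator onto an ArmNN Shape layer, producing a 1D INT64
// tensor holding the input's dimensions; e.g. a [2,3,4] input yields the values
// {2,3,4} in a tensor of shape [3] (shapes illustrative).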
void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    // Output must be INT64
    CHECK_VALID_DATATYPE(node.name(), node.output(0),
                         m_TensorsInfo[node.output(0)].m_dtype,
                         onnx::TensorProto::INT64);

    IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

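// Maps the ONNX Reshape operator. The target shape (input 1) must be a constant
// INT64 tensor; a constant data input is forwarded as a constant tensor instead
// of adding a reshape layer.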
void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT); // input
    CHECK_VALID_DATATYPE(node.name(), node.input(1),
                         m_TensorsInfo[node.input(1)].m_dtype,
                         onnx::TensorProto::INT64); // shape

    if(!m_TensorsInfo[node.input(1)].isConstant())
    {
        throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
                                         node.input(1),
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }

    if(m_TensorsInfo[node.input(0)].isConstant())
    {
        // make a new constant tensor -> move the data to the output tensor
        // (the shape is already correct in the output tensor)
        if(m_TensorsInfo.count(node.output(0)) == 0)
        {
            m_TensorsInfo[node.output(0)] = OnnxTensor();
        }
        m_TensorsInfo[node.output(0)].m_tensor =
            std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
    }
    else
    {
        TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

        if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
        {
            uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());
            TensorShape targetShape{static_cast<unsigned int>(dims), 1};

            for(uint i = 0; i < dims; i++)
            {
                int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
                targetShape[i]= static_cast<unsigned int>(val);
            }

            auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
            m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
        }

        CreateReshapeLayer(node.input(0), node.output(0), node.name());
    }
}

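// Prepends leading 1s to input0's shape so it has as many dimensions as input1,
// registering the reshaped tensor under outputName; e.g. a [4] tensor paired with
// a [2,3,4] tensor is reshaped to [1,1,4] (shapes illustrative).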
void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
                                         const std::string& input0,
                                         const std::string& input1)
{
    // input0 should be reshaped to have the same number of dimensions as input1
    TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    uint32_t diff = input1Shape.GetNumDimensions() - input0Shape.GetNumDimensions();
    std::vector<uint32_t> newShape;
    while(diff > 0)
    {
        newShape.push_back(1);
        diff--;
    }
    for (uint dim = 0; dim < input0Shape.GetNumDimensions(); ++dim)
    {
        newShape.push_back(input0Shape[dim]);
    }
    outputTensorInfo.SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));

    // add the new tensor to m_TensorsInfo
    m_TensorsInfo[outputName] = OnnxTensor();
    m_TensorsInfo[outputName].m_info = std::make_unique<TensorInfo>(outputTensorInfo);

    // add a reshape layer if the parent was not constant...
    if( ! m_TensorsInfo[input0].isConstant())
    {
        CreateReshapeLayer(input0, outputName, fmt::format("Add:reshapeOf{}", input0));
    }
    else // make it constant; it will be created in Add
    {
        m_TensorsInfo[outputName].m_tensor = std::make_unique<onnx::TensorProto>(*m_TensorsInfo[input0].m_tensor);
    }
}

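// Adds an ArmNN InputLayer for every graph input that is not a constant
// initializer; the binding id is the input's index in the ONNX graph.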
void OnnxParserImpl::SetupInputLayers()
{
    // Find user inputs and add their layers
    for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
    {
        auto input = m_Graph->input(inputIndex);
        if (! m_TensorsInfo[input.name()].isConstant())
        {
            IConnectableLayer* layer =
                m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
            auto tensorInfo = ToTensorInfo(input);
            layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

            RegisterOutputSlots(layer, { input.name() });
        }
    }
}

void OnnxParserImpl::SetupOutputLayers()
{
    if(m_Graph->output_size() == 0)
    {
        throw ParseException(fmt::format("The given model does not have any outputs {}", CHECK_LOCATION().AsString()));
    }

    for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
    {
        IConnectableLayer* layer =
            m_Network->AddOutputLayer(static_cast<armnn::LayerBindingId>(outputIndex),
                                      m_Graph->output(outputIndex).name().c_str());

        RegisterInputSlots(layer, { m_Graph->output(outputIndex).name() });
    }
}

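// Slot registration: layers are created first and connections are resolved
// afterwards by tensor name, so each tensor id accumulates its consumer input
// slots (and its single producer output slot) in m_TensorConnections.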
void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumInputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor inputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumInputSlots(),
                        CHECK_LOCATION().AsString()));
    }

    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            // First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }
        m_TensorConnections[tensorId].inputSlots.push_back(slot);
    }
}

void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumOutputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor outputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumOutputSlots(),
                        CHECK_LOCATION().AsString()));
    }

    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            // First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }

        TensorSlots& tensorSlots = m_TensorConnections[tensorId];

        // assuming there is only one producer for that tensor
        if (tensorSlots.outputSlot != nullptr)
        {
            throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
                                             "tensor:{} {}",
                                             tensorId,
                                             CHECK_LOCATION().AsString()));
        }
        tensorSlots.outputSlot = slot;
    }
}

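// The two lookups below resolve a binding point by tensor name. A minimal usage
// sketch (tensor name "data" is assumed, not taken from a specific model):
//   BindingPointInfo info = parser.GetNetworkInputBindingInfo("data");
//   armnn::LayerBindingId id = info.first;    // binding id == graph input index
//   armnn::TensorInfo tensorInfo = info.second;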
BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->input_size(); ++i)
    {
        auto input = m_Graph->input(i);
        if(input.name() == name)
        {
            return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
        }
    }
    throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->output_size(); ++i)
    {
        auto output = m_Graph->output(i);
        if(output.name() == name)
        {
            return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
        }
    }
    throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

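// Returns the names of the graph's real (user-supplied) inputs, i.e. graph
// inputs that do not appear among the constant initializers.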
std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model)
{
    if(model == nullptr) {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> inputNames;
    std::map<std::string, bool> isConstant;
    for(const auto& tensor : model->graph().initializer())
    {
        isConstant[tensor.name()] = true;
    }
    for(const auto& input : model->graph().input())
    {
        auto it = isConstant.find(input.name());
        if(it == isConstant.end())
        {
            inputNames.push_back(input.name());
        }
    }
    return inputNames;
}

std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model)
{
    if(model == nullptr) {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> outputNames;
    for(const auto& output : model->graph().output())
    {
        outputNames.push_back(output.name());
    }
    return outputNames;
}

const std::string OnnxParserImpl::GetVersion()
{
    return ONNX_PARSER_VERSION;
}

} // namespace armnnOnnxParser