//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "OnnxParser.hpp"

#include "armnnOnnxParser/Version.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <VerificationHelpers.hpp>

#include <fmt/format.h>

#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>

#include <iostream>
#include <numeric>

using namespace armnn;

namespace armnnOnnxParser
{

IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {}

IOnnxParser::~IOnnxParser() = default;

IOnnxParser* IOnnxParser::CreateRaw()
{
    return new IOnnxParser();
}

IOnnxParserPtr IOnnxParser::Create()
{
    return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy);
}

void IOnnxParser::Destroy(IOnnxParser* parser)
{
    delete parser;
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile)
{
    return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile);
}

armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText)
{
    return pOnnxParserImpl->CreateNetworkFromString(protoText);
}

BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
}

BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const
{
    return pOnnxParserImpl->GetNetworkOutputBindingInfo(name);
}
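
// A minimal usage sketch of the public IOnnxParser facade above (file name and
// tensor names are hypothetical; error handling elided):
//
//     armnnOnnxParser::IOnnxParserPtr parser = armnnOnnxParser::IOnnxParser::Create();
//     armnn::INetworkPtr network = parser->CreateNetworkFromBinaryFile("model.onnx");
//     armnn::BindingPointInfo inputInfo  = parser->GetNetworkInputBindingInfo("input");
//     armnn::BindingPointInfo outputInfo = parser->GetNetworkOutputBindingInfo("output");
//
// The binding infos pair a layer binding id with the TensorInfo needed to feed
// and read the network once it has been optimised and loaded onto a runtime.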

namespace
{
void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes,
                        const onnx::TensorProto::DataType actualValue,
                        const char* validExpr,
                        std::string nodeName,
                        std::string tensorName,
                        const armnn::CheckLocation& location)
{
    bool isValid = std::any_of(validInputTypes.begin(),
                               validInputTypes.end(),
                               [&actualValue](onnx::TensorProto::DataType x) { return x == actualValue; } );
    if (!isValid)
    {
        throw ParseException(
            fmt::format("Datatype {} is not valid for tensor '{}' of node '{}', not in {{{}}}. {}",
                        onnx::TensorProto::DataType_Name(actualValue),
                        tensorName,
                        nodeName,
                        validExpr,
                        location.AsString()));
    }
}

#define CHECK_VALID_DATATYPE(NODE, TENSOR, ACTUAL, ...) \
CheckValidDataType({__VA_ARGS__}, ACTUAL, #__VA_ARGS__, NODE, TENSOR, CHECK_LOCATION())

using StrTypeListPair = std::pair<const char*, std::initializer_list<onnx::TensorProto::DataType>>;
#define STR_LIST(...) StrTypeListPair(#__VA_ARGS__, {__VA_ARGS__})
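
// How these two macros expand (illustrative only): a call such as
//
//     CHECK_VALID_DATATYPE(node.name(), tensorName, actualType, onnx::TensorProto::FLOAT);
//
// becomes
//
//     CheckValidDataType({onnx::TensorProto::FLOAT}, actualType,
//                        "onnx::TensorProto::FLOAT", node.name(), tensorName, CHECK_LOCATION());
//
// so the stringised argument list can be echoed back in the error message.
// STR_LIST similarly pairs the stringised type list with the list itself, for
// use with the VALID_INPUTS macro further down.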

template <typename Callable>
void ReadMandatoryNodeAttributeImpl(const onnx::NodeProto& node,
                                    const std::string& attribName,
                                    onnx::AttributeProto::AttributeType expectedType,
                                    Callable callable)
{
    auto attribs = node.attribute();
    int attriNum = 0;
    while (attriNum < node.attribute_size())
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(fmt::format("Attribute {} of node {} expected to have {} as "
                                                 "onnx::AttributeProto::AttributeType, but found {} instead {}",
                                                 attribName,
                                                 node.name(),
                                                 onnx::AttributeProto::AttributeType_Name(expectedType),
                                                 onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                                 CHECK_LOCATION().AsString()));
            }
            break;
        }
        ++attriNum;
    }
    if (attriNum == node.attribute_size())
    {
        throw ParseException(fmt::format("Could not find required attribute {} in node {} {}",
                                         attribName, node.name(), CHECK_LOCATION().AsString()));
    }
}

template <typename Callable>
void ReadOptionalNodeAttributeImpl(const onnx::NodeProto& node,
                                   const std::string& attribName,
                                   onnx::AttributeProto::AttributeType expectedType,
                                   Callable callable)
{
    auto attribs = node.attribute();
    for (int attriNum = 0; attriNum < node.attribute_size(); ++attriNum)
    {
        if (attribs.Get(attriNum).name() == attribName)
        {
            if (attribs.Get(attriNum).type() == expectedType)
            {
                callable(attribs.Get(attriNum));
            }
            else
            {
                throw ParseException(
                    fmt::format("Attribute {} of node {} expected to have {} as onnx::AttributeProto::AttributeType, "
                                "but found {} instead {}",
                                attribName,
                                node.name(),
                                onnx::AttributeProto::AttributeType_Name(expectedType),
                                onnx::AttributeProto::AttributeType_Name(attribs.Get(attriNum).type()),
                                CHECK_LOCATION().AsString()));
            }
        }
    }
}

int64_t ReadOptionalNodeInt64Attribute(const onnx::NodeProto& node,
                                       const std::string& name,
                                       const int64_t defaultValue = 0)
{
    int64_t attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.i();
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadMandatoryNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                           const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadMandatoryNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                   [&attriList](const onnx::AttributeProto& attrValue)
                                   {
                                       for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                       {
                                           attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                       }
                                   });
    return attriList;
}

uint32_t ReadOptionalNodeUint32Attribute(const onnx::NodeProto& node,
                                         const std::string& name,
                                         const uint32_t defaultVal = 0u)
{
    uint32_t attribValue = defaultVal;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = CHECKED_NON_NEGATIVE(CHECKED_INT32((attrValue.i())));
                                  });
    return attribValue;
}

std::vector<uint32_t> ReadOptionalNodeUint32ListAttribute(const onnx::NodeProto& node,
                                                          const std::string& name)
{
    std::vector<uint32_t> attriList;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::INTS,
                                  [&attriList](const onnx::AttributeProto& attrValue)
                                  {
                                      for (int attriNum = 0; attriNum < attrValue.ints_size(); ++attriNum)
                                      {
                                          attriList.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(attrValue.ints().Get(attriNum))));
                                      }
                                  });

    return attriList;
}

float ReadOptionalNodeFloatAttribute(const onnx::NodeProto& node,
                                     const std::string& name,
                                     const float defaultValue = 0.0f)
{
    float attribValue = defaultValue;
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::FLOAT,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.f();
                                  });
    return attribValue;
}

std::string ReadOptionalNodeStringAttribute(const onnx::NodeProto& node, const std::string& name)
{
    std::string attribValue = "";
    ReadOptionalNodeAttributeImpl(node, name, onnx::AttributeProto::STRING,
                                  [&attribValue](const onnx::AttributeProto& attrValue)
                                  {
                                      attribValue = attrValue.s();
                                  });
    return attribValue;
}
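
// The readers above share one pattern: find the attribute by name, check its
// protobuf type, then hand the AttributeProto to a lambda that extracts the
// value. Illustrative use against a Conv node (attribute names are from the
// ONNX operator spec and are used further down in this file):
//
//     std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
//     uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
//     std::string autoPad = ReadOptionalNodeStringAttribute(node, "auto_pad");
//
// Optional readers fall back to the supplied default when the attribute is
// absent; the mandatory variant throws a ParseException instead.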

armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int>& shape, int data_type)
{
    DataType type;
    switch(data_type)
    {
        case onnx::TensorProto::FLOAT:
        {
            type = DataType::Float32;
            break;
        }
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::INT64:
        {
            type = DataType::Signed32;
            break;
        }
        default:
        {
            throw ParseException(
                fmt::format("'{}' is not a currently supported datatype for tensor {}."
                            " Supported dataTypes are FLOAT, INT32 and INT64. {}",
                            onnx::TensorProto::DataType_Name(static_cast<onnx::TensorProto::DataType>(data_type)),
                            name,
                            CHECK_LOCATION().AsString() ));
        }
    }

    // To avoid crashes caused by trivial (zero-dimensional) tensors
    if (shape.empty())
    {
        return TensorInfo(TensorShape(), type);
    }

    return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}

armnn::TensorInfo ToTensorInfo(const onnx::ValueInfoProto& info)
{
    const onnx::TensorShapeProto onnxShape = info.type().tensor_type().shape();
    std::vector<unsigned int> shapeDims;
    for (int i = 0; i < onnxShape.dim_size(); ++i)
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(onnxShape.dim(i).dim_value())));
    }

    if (shapeDims.empty())
    {
        shapeDims.push_back(1);
    }

    return ToTensorInfo(info.name(), shapeDims, info.type().tensor_type().elem_type());
}

armnn::TensorInfo ToTensorInfo(const onnx::TensorProto& tensor)
{
    std::vector<unsigned int> shapeDims;

    for (auto dim: tensor.dims())
    {
        shapeDims.push_back(CHECKED_NON_NEGATIVE(CHECKED_INT32(dim)));
    }

    if (shapeDims.empty())
    {
        shapeDims.push_back(1);
    }

    return ToTensorInfo(tensor.name(), shapeDims, tensor.data_type());
}

std::string TensorInfoAsString(const TensorInfo& info,
                               const std::string& name,
                               const onnx::TensorProto::DataType& type)
{
    const TensorShape shape = info.GetShape();
    std::stringstream ss;
    ss << "tensor '" << name << "' contains "
       << onnx::TensorProto::DataType_Name(type)
       << " and has shape [";

    for (uint32_t i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        ss << shape[i] << ", ";
    }
    ss << shape[shape.GetNumDimensions() - 1] << "]";
    return ss.str();
}

void CalcPadding(uint32_t inputSize,
                 uint32_t filterSize,
                 uint32_t stride,
                 uint32_t dilation,
                 uint32_t* paddingFront,
                 uint32_t* paddingBack,
                 bool isUpper)
{
    uint32_t outputSize = (inputSize + stride - 1) / stride;
    uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
    uint32_t temp = (outputSize - 1) * stride + dilatedSize;
    *paddingFront = (temp - inputSize) / 2;
    *paddingBack = *paddingFront;
    if((temp - inputSize) % 2 == 1)
    {
        if (isUpper)
        {
            *paddingBack += 1;
        }
        else
        {
            *paddingFront += 1;
        }
    }
}
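
// A worked example of the SAME padding arithmetic above (values chosen for
// illustration): inputSize = 5, filterSize = 3, stride = 2, dilation = 1 gives
//
//     outputSize  = (5 + 2 - 1) / 2 = 3
//     dilatedSize = 3
//     temp        = (3 - 1) * 2 + 3 = 7
//     total pad   = temp - inputSize = 2  ->  paddingFront = paddingBack = 1
//
// When the total padding is odd, the extra cell goes to the back for
// SAME_UPPER (isUpper == true) and to the front for SAME_LOWER, matching the
// ONNX auto_pad semantics.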

TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor,
                              const TensorShape& inShape,
                              const std::string& outName)
{
    std::vector<int> targetDims;
    for(uint i = 0; i < targetShapeTensor.GetNumDimensions(); ++i)
    {
        int val = CHECKED_INT32(targetShapeTensor[i]);
        if(val == 0)
        {
            targetDims.push_back(static_cast<int>(inShape[static_cast<uint>(i)]));
        }
        else
        {
            targetDims.push_back(val);
        }
    }

    std::vector<unsigned int> outDims(targetDims.begin(), targetDims.end());
    const auto stretchDim = std::find(targetDims.begin(), targetDims.end(), -1);
    if (stretchDim != targetDims.end())
    {
        if (std::find(std::next(stretchDim), targetDims.end(), -1) != targetDims.end())
        {
            std::stringstream ss;
            ss << "[ ";
            for(uint i = 0; i < targetDims.size() - 1; ++i)
            {
                ss << targetDims[i] << ", ";
            }
            ss << targetDims[targetDims.size() - 1] << " ]";

            throw ParseException(
                fmt::format("Error during creation of reshaped tensor '{}'. At most one component of shape can be "
                            "-1 and here, shape is {} {}",
                            outName,
                            ss.str(),
                            CHECK_LOCATION().AsString()));
        }

        auto targetNumElements = armnn::numeric_cast<unsigned int>(std::accumulate(targetDims.begin(), targetDims.end(),
                                                                                   -1, std::multiplies<int32_t>()));
        auto stretchIndex = static_cast<size_t>(std::distance(targetDims.begin(), stretchDim));
        outDims[stretchIndex] = inShape.GetNumElements() / targetNumElements;
    }
    TensorShape outShape = TensorShape{static_cast<unsigned int>(outDims.size()), outDims.data()};
    return TensorInfo(outShape, DataType::Float32);
}
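
// Worked example for ComputeReshapeInfo (shapes illustrative): a 0 in the
// target shape copies the corresponding input dimension and a single -1 is
// inferred from the element count. For an input of shape [2, 3, 4]
// (24 elements) and a target of [0, -1]:
//
//     targetDims = { 2, -1 }                           // 0 copied input dim 0
//     std::accumulate(..., -1, multiplies) = (-1) * 2 * (-1) = 2
//     outDims    = { 2, 24 / 2 } = { 2, 12 }
//
// The accumulate seed of -1 cancels the single -1 in the list, which is why
// at most one stretch dimension is allowed.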

} //namespace

const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = {
    { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization},
    { "GlobalAveragePool",  &OnnxParserImpl::ParseGlobalAveragePool},
    { "AveragePool",        &OnnxParserImpl::ParseAveragePool },
    { "Clip",               &OnnxParserImpl::ParseClip },
    { "Constant",           &OnnxParserImpl::ParseConstant },
    { "MaxPool",            &OnnxParserImpl::ParseMaxPool },
    { "Reshape",            &OnnxParserImpl::ParseReshape },
    { "Sigmoid",            &OnnxParserImpl::ParseSigmoid },
    { "Tanh",               &OnnxParserImpl::ParseTanh },
    { "Relu",               &OnnxParserImpl::ParseRelu },
    { "LeakyRelu",          &OnnxParserImpl::ParseLeakyRelu },
    { "Conv",               &OnnxParserImpl::ParseConv },
    { "Add",                &OnnxParserImpl::ParseAdd },
    { "Flatten",            &OnnxParserImpl::ParseFlatten},
};

template<typename TypePair, typename Location>
void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node,
                                    TypePair validInputs,
                                    const Location& location)
{
    for(auto input : node.input())
    {
        CheckValidDataType(validInputs.second,
                           m_TensorsInfo[input].m_dtype,
                           validInputs.first,
                           node.name(),
                           input,
                           location);
    }
}

#define VALID_INPUTS(NODE, VALID_INPUTS) \
    OnnxParserImpl::ValidateInputs(NODE, \
                                   VALID_INPUTS, \
                                   CHECK_LOCATION())

std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames,
                                                          const IConnectableLayer* layer,
                                                          std::vector<TensorShape> inputShapes)
{
    ARMNN_ASSERT(! outNames.empty());
    bool needCompute = std::any_of(outNames.begin(),
                                   outNames.end(),
                                   [this](std::string name)
                                   {
                                       return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
                                   });
    std::vector<TensorInfo> outInfo;
    //if the output info(s) are not here, we need to compute them
    std::vector<TensorShape> inferredShapes;
    if(needCompute)
    {
        inferredShapes = layer->InferOutputShapes(inputShapes);
        ARMNN_ASSERT(inferredShapes.size() == outNames.size());
    }
    for (uint i = 0; i < outNames.size(); ++i)
    {
        if(needCompute)
        {
            m_TensorsInfo[outNames[i]] = OnnxTensor();
            m_TensorsInfo[outNames[i]].m_info = std::make_unique<TensorInfo>(
                TensorInfo(inferredShapes[i], DataType::Float32));
        }
        outInfo.push_back(*m_TensorsInfo[outNames[i]].m_info);
    }
    return outInfo;
}

OnnxParserImpl::OnnxParserImpl()
    : m_Network(nullptr, nullptr)
{
}

void OnnxParserImpl::ResetParser()
{
    m_Network = armnn::INetworkPtr(nullptr, nullptr);
    m_Graph = nullptr;
}

void OnnxParserImpl::Cleanup()
{
    m_TensorConnections.clear();
    m_TensorsInfo.clear();
    m_OutputsMap.clear();
    m_OutputsFusedAndUsed.clear();
}

std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParserImpl::CreateConstTensor(const std::string name)
{
    const TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;

    auto srcData = onnxTensor.float_data().data();
    std::unique_ptr<float[]> tensorData(new float[tensorInfo.GetNumElements()]);
    const size_t tensorSizeInBytes = tensorInfo.GetNumBytes();
    // Copy the value list entries into the destination
    if (!onnxTensor.has_raw_data())
    {
        if(tensorInfo.GetNumElements() != static_cast<uint>(onnxTensor.float_data_size()))
        {
            throw ParseException(
                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
                            "elements ({}) {}",
                            onnxTensor.float_data_size(),
                            name,
                            tensorInfo.GetNumElements(),
                            CHECK_LOCATION().AsString()));
        }
        ::memcpy(tensorData.get(), srcData, tensorSizeInBytes);
    }
    else
    {
        ::memcpy(tensorData.get(), onnxTensor.raw_data().c_str(), tensorSizeInBytes);
    }

    // Const tensors require at least a list of values
    if (tensorInfo.GetNumElements() == 0)
    {
        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
                                         name,
                                         CHECK_LOCATION().AsString()));
    }
    return std::make_pair(ConstTensor(tensorInfo, tensorData.get()), std::move(tensorData));
}
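
// Note on the pair returned above: ConstTensor does not own its backing
// memory, so the unique_ptr<float[]> must be kept alive for as long as the
// ConstTensor (and any layer construction that reads it) is in use. Callers
// below therefore hold the whole pair, e.g.
//
//     auto weightTensor = CreateConstTensor(node.input(1));
//     // ... pass weightTensor.first to AddConvolution2dLayer and friends ...
//
// and let the unique_ptr in weightTensor.second release the data afterwards.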

ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "r");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    using google::protobuf::io::FileInputStream;
    std::unique_ptr<FileInputStream> input = std::make_unique<FileInputStream>(fileno(fd));
    bool success = google::protobuf::TextFormat::Parse(input.get(), modelProto.get());
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromTextFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}


ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
    FILE* fd = fopen(graphFile, "rb");

    if (fd == nullptr)
    {
        throw FileNotFoundException(fmt::format("Invalid (null) filename {}", CHECK_LOCATION().AsString()));
    }

    // Parse the file into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();

    google::protobuf::io::FileInputStream inStream(fileno(fd));
    google::protobuf::io::CodedInputStream codedStream(&inStream);
    codedStream.SetTotalBytesLimit(INT_MAX);
    bool success = modelProto.get()->ParseFromCodedStream(&codedStream);
    fclose(fd);

    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;

}

INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
    return CreateNetworkFromModel(*modelProto);
}

ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
{
    if (protoText == "")
    {
        throw InvalidArgumentException(fmt::format("Invalid (empty) string for model parameter {}",
                                                   CHECK_LOCATION().AsString()));
    }
    // Parse the string into a message
    ModelPtr modelProto = std::make_unique<onnx::ModelProto>();
    bool success = google::protobuf::TextFormat::ParseFromString(protoText, modelProto.get());
    if (!success)
    {
        std::stringstream error;
        error << "Failed to parse graph file";
        throw ParseException(fmt::format("{} {}", error.str(), CHECK_LOCATION().AsString()));
    }
    return modelProto;
}

INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText)
{
    ResetParser();
    ModelPtr modelProto = LoadModelFromString(protoText);
    return CreateNetworkFromModel(*modelProto);
}

INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
{
    m_Network = INetwork::Create();
    try
    {
        m_Graph = std::make_unique<onnx::GraphProto>(*model.mutable_graph());
        LoadGraph();
    }
    catch (const ParseException& e)
    {
        Cleanup();
        throw e;
    }
    Cleanup();
    return std::move(m_Network);
}

void OnnxParserImpl::LoadGraph()
{
    ARMNN_ASSERT(m_Graph.get() != nullptr);

    //Fill m_TensorsInfo with the shapes and values of every tensor
    SetupInfo(m_Graph->mutable_output());
    SetupInfo(m_Graph->mutable_input());
    SetupInfo(m_Graph->mutable_value_info());

    for (auto tensor : m_Graph->initializer())
    {
        m_TensorsInfo[tensor.name()].m_tensor = std::make_unique<const onnx::TensorProto>(tensor);
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.data_type());
    }

    SetupInputLayers();
    SetupOutputLayers();

    //Detect FullyConnected layers with bias and update the FusedAndUsed map accordingly
    DetectFullyConnected();

    //Parsing the graph
    for(size_t nodeIndex = 0; nodeIndex < static_cast<size_t>(m_Graph->node_size()); nodeIndex++)
    {
        auto node = m_Graph->node(static_cast<int>(nodeIndex));
        const std::string& operation = node.op_type();

        // check which layers we handled already (add and matmul fused as FC)
        if (operation == "MatMul" )
        {
            if(m_OutputsFusedAndUsed[nodeIndex].inputForNodes != m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.size())
            {
                //Node which can not be fused as a FullyConnected layer (used in layers as a simple matmul output)
                AddFullyConnected(node);
            }
        }
        else if (!(m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) && operation == "Add")
        {
            int matmulIndex = static_cast<int> (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes[0]);
            AddFullyConnected(m_Graph->node(matmulIndex), &node);
        }
        else if (m_OutputsFusedAndUsed[nodeIndex].fusedWithNodes.empty()) //node is not part of a fused layer
        {
            auto it = m_ParserFunctions.find(operation);
            if (it != m_ParserFunctions.end())
            {
                auto func = it->second;
                (this->*func)(node);
            }
            else
            {
                throw ParseException(fmt::format("Unsupported operation {} for node '{}' {}",
                                                 operation,
                                                 node.name(),
                                                 CHECK_LOCATION().AsString()));
            }
        }
    }

    //Making the connections between outputs and inputs of each layer
    for (const auto& tensorCon : m_TensorConnections)
    {
        if (tensorCon.second.outputSlot != nullptr)
        {
            for (size_t inputSlotIdx = 0; inputSlotIdx < tensorCon.second.inputSlots.size(); ++inputSlotIdx)
            {
                tensorCon.second.outputSlot->Connect(*(tensorCon.second.inputSlots[inputSlotIdx]));
            }
        }
    }
}

void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
{
    for (auto tensor : *list)
    {
        m_TensorsInfo[tensor.name()] = OnnxTensor();
        m_TensorsInfo[tensor.name()].m_info = std::make_unique<TensorInfo>(ToTensorInfo(tensor));
        m_TensorsInfo[tensor.name()].m_dtype =
            static_cast<onnx::TensorProto::DataType>(tensor.type().tensor_type().elem_type());
    }
}

void OnnxParserImpl::DetectFullyConnected()
{
    m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary());
    auto matmulAndConstant = [&](const std::string& constInput,
                                 const std::string& matmulInput,
                                 int& nodeIndex)
    {
        auto matmulIt = m_OutputsMap.find(matmulInput);
        if(matmulIt != m_OutputsMap.end() && matmulIt->second.first->op_type() == "MatMul"
            && m_TensorsInfo[constInput].isConstant())
        {
            nodeIndex = matmulIt->second.second;
            return true;
        }
        return false;
    };

    for(int nodeIndex = 0; nodeIndex < m_Graph->node_size(); nodeIndex++)
    {
        const onnx::NodeProto* node = &m_Graph->node(nodeIndex);
        for (const std::string& output : node->output())
        {
            m_OutputsMap[output] = std::make_pair(node, nodeIndex);
        }

        for (const std::string& input : node->input()) //count how many times a node is used as input
        {
            auto matmulIt = m_OutputsMap.find(input);
            if(matmulIt != m_OutputsMap.end()){
                ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes; //node used
            }
        }

        if (node->op_type() == "Add")
        {
            int matmulIndex = 0;
            if (matmulAndConstant(node->input(0), node->input(1), matmulIndex) ||
                matmulAndConstant(node->input(1), node->input(0), matmulIndex))
            {
                //matmul and add were fused
                m_OutputsFusedAndUsed[static_cast<size_t>(matmulIndex)].fusedWithNodes
                    .push_back(static_cast<size_t>(nodeIndex));

                m_OutputsFusedAndUsed[static_cast<size_t>(nodeIndex)].fusedWithNodes
                    .push_back(static_cast<size_t>(matmulIndex));
            }
        }
    }

    for (auto output: m_Graph->output()) { //graph outputs count as usages too
        auto matmulIt = m_OutputsMap.find(output.name());
        if(matmulIt != m_OutputsMap.end()){
            ++m_OutputsFusedAndUsed[static_cast<size_t>(matmulIt->second.second)].inputForNodes;
        }
    }
}
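
// Sketch of the fusion DetectFullyConnected prepares (names illustrative): for
// a subgraph
//
//     Y = MatMul(X, W)      // W constant
//     Z = Add(Y, B)         // B constant
//
// the MatMul and Add node indices are recorded in each other's fusedWithNodes,
// and LoadGraph later emits a single FullyConnected layer with weights W and
// bias B instead of two separate layers. A MatMul whose output is also
// consumed elsewhere (inputForNodes exceeds the fused count) still gets its
// own layer.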

template<typename Location>
void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node,
                                      std::string* inputName,
                                      std::string* constName,
                                      const Location& location)
{
    int cstIndex;
    if (m_TensorsInfo[node.input(0)].isConstant())
    {
        cstIndex = 0;
    }
    else if (m_TensorsInfo[node.input(1)].isConstant())
    {
        cstIndex = 1;
    }
    else
    {
        throw ParseException(fmt::format("One of the input tensors ('{}' or '{}') should be constant in node '{}' {}",
                                         node.input(0),
                                         node.input(1),
                                         node.name(),
                                         location.AsString()));
    }
    if(constName)
    {
        *constName = node.input(cstIndex);
    }
    if(inputName)
    {
        *inputName = node.input(!cstIndex);
    }
}

template<typename Location>
void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location)
{
    TensorShape shape = m_TensorsInfo[name].m_info->GetShape();
    std::vector<uint32_t> newShape;
    for(uint i = 0; i < shape.GetNumDimensions() - 1; ++i)
    {
        if(shape[i] != 1)
        {
            throw ParseException(
                fmt::format("Only tensors with shape [1, ..., 1, X] can be converted to 1D and {} {}",
                            TensorInfoAsString(*m_TensorsInfo[name].m_info, name, m_TensorsInfo[name].m_dtype),
                            location.AsString()));
        }
    }
    newShape.push_back(shape[shape.GetNumDimensions() - 1]);

    m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));
}
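
// To1DTensor in one line: a tensor stored as [1, ..., 1, X] (all leading dims
// equal to 1) is re-shaped in place to [X], which is the 1D layout ArmNN
// expects for bias tensors; any other leading dimension raises a
// ParseException. For example, a bias of shape [1, 256] becomes [256].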

void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc)
{
    ARMNN_ASSERT(node.op_type() == "Conv");

    DepthwiseConvolution2dDescriptor desc;
    desc.m_PadLeft = convDesc.m_PadLeft;
    desc.m_PadRight = convDesc.m_PadRight;
    desc.m_PadTop = convDesc.m_PadTop;
    desc.m_PadBottom = convDesc.m_PadBottom;
    desc.m_StrideX = convDesc.m_StrideX;
    desc.m_StrideY = convDesc.m_StrideY;
    desc.m_BiasEnabled = convDesc.m_BiasEnabled;

    armnn::IConnectableLayer* layer;
    auto weightTensor = CreateConstTensor(node.input(1));
    TensorShape& weightShape = weightTensor.first.GetShape();
    weightShape[1] = weightShape[0];
    weightShape[0] = 1;
    m_TensorsInfo[node.input(1)].m_info->SetShape(weightShape);

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
                                                          weightTensor.first,
                                                          Optional<ConstTensor>(biasTensor.first),
                                                          node.name().c_str());
    }
    else
    {
        layer = m_Network->AddDepthwiseConvolution2dLayer(desc,
                                                          weightTensor.first,
                                                          EmptyOptional(),
                                                          node.name().c_str());
    }
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });

    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
{

    // find matmul inputs
    std::string weightName;
    std::string inputName;
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
    VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));

    GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());

    FullyConnectedDescriptor desc;
    desc.m_BiasEnabled = addNode != nullptr;

    IConnectableLayer* layer = nullptr;
    if(desc.m_BiasEnabled)
    {
        // find bias const
        std::string biasName;
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
        CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
        VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));

        GetInputAndParam(*addNode, nullptr, &biasName, CHECK_LOCATION());

        //Output shape is [1, weights[1]] and 1d vec in ONNX can be [1,X] so we convert biases to "armnn" 1D
        To1DTensor(biasName, CHECK_LOCATION());
        TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
        TensorInfo biasInfo = *m_TensorsInfo[biasName].m_info;

        if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
        {
            throw ParseException(
                fmt::format("Shape of weights '{}' and bias of following Add node '{}' do not match : {}"
                            " and {} ( /!\\ bias should be a 1D tensor) {}",
                            weightName,
                            addNode->name(),
                            TensorInfoAsString(*m_TensorsInfo[weightName].m_info, weightName,
                                               m_TensorsInfo[weightName].m_dtype),
                            TensorInfoAsString(*m_TensorsInfo[biasName].m_info, biasName,
                                               m_TensorsInfo[biasName].m_dtype ),
                            CHECK_LOCATION().AsString()));
        }
        layer = m_Network->AddFullyConnectedLayer(desc,
                                                  CreateConstTensor(weightName).first,
                                                  Optional<ConstTensor>(CreateConstTensor(biasName).first),
                                                  matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({addNode->output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});

        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        RegisterInputSlots(layer, {inputName});
        RegisterOutputSlots(layer, {addNode->output(0)});
    }
    else
    {
        layer = m_Network->AddFullyConnectedLayer(desc,
                                                  CreateConstTensor(weightName).first,
                                                  EmptyOptional(),
                                                  matmulNode.name().c_str());
        ARMNN_ASSERT(layer != nullptr);

        auto outputInfo = ComputeOutputInfo({matmulNode.output(0)}, layer,
                                            {m_TensorsInfo[inputName].m_info->GetShape(),
                                             m_TensorsInfo[weightName].m_info->GetShape()});
        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

        RegisterInputSlots(layer, {inputName});
        RegisterOutputSlots(layer, {matmulNode.output(0)});
    }
}

void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)
{

    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));

    std::vector<uint32_t> kernel_shape = ReadMandatoryNodeUint32ListAttribute(node, "kernel_shape"); //size of pooling window
    std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");

    desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
    desc.m_PoolWidth = kernel_shape[1];
    desc.m_PoolHeight = kernel_shape[0];

    if(strides.empty())
    {
        desc.m_StrideX = 1;
        desc.m_StrideY = 1;
    }
    else
    {
        desc.m_StrideX = strides[1];
        desc.m_StrideY = strides[0];
    }

    //Check new padding version first
    if(pads.empty())
    {
        //Check deprecated version
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if( paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if (paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(fmt::format("Invalid auto_pad attribute for node {}. "
                                                 "Only SAME_UPPER, SAME_LOWER or VALID are supported; found {} {}",
                                                 node.name(),
                                                 paddingString,
                                                 CHECK_LOCATION().AsString()));
            }
            auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth = inputInfo.GetShape()[3];
            CalcPadding(inputHeight,
                        desc.m_PoolHeight,
                        desc.m_StrideY,
                        1u,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        desc.m_PoolWidth,
                        desc.m_StrideX,
                        1u,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop = pads[0];
        desc.m_PadLeft = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight = pads[3];
    }

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0,
                                                                        const std::string& input1)
{
    std::pair<std::string, std::string> inputs = std::make_pair(input0, input1);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

    if(input1Shape.GetNumDimensions() < input0Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input1);
        PrependForBroadcast(outputName, input1, input0);
        inputs.second = outputName;
    }
    else if(input0Shape.GetNumDimensions() < input1Shape.GetNumDimensions())
    {
        auto outputName = fmt::format("reshape_output_{}", input0);
        PrependForBroadcast(outputName, input0, input1);
        inputs.first = outputName;
    }
    return inputs;
}
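
// Broadcast preparation, illustrated (tensor names hypothetical): adding a
// bias of shape [256] to an input of shape [1, 256] has mismatched ranks, so
// AddPrepareBroadcast inserts a reshape that prepends 1s to the smaller
// tensor, [256] -> [1, 256], and returns the (possibly renamed) pair of input
// tensor names to wire into the Add layer.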
1097
Kevin Mayef33cb12021-01-29 14:24:57 +00001098void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001099{
1100 auto armnnTensor = CreateConstTensor(tensorName);
1101
1102 IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
1103 layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
1104 RegisterOutputSlots(layer, {tensorName});
1105}
1106
Kevin Mayef33cb12021-01-29 14:24:57 +00001107void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName,
1108 const std::string& outputName,
1109 const std::string& layerName)
telsoa01c577f2c2018-08-31 09:22:23 +01001110{
1111 const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info;
1112 ReshapeDescriptor reshapeDesc;
1113 reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
1114
1115 IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001116 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001117 layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1118
1119 // register the input connection slots for the layer, connections are made after all layers have been created
1120 // only the tensors for the inputs are relevant, exclude the const tensors
1121 RegisterInputSlots(layer, {inputName});
1122
1123 // register the output connection slots for the layer, connections are made after all layers have been created
1124 RegisterOutputSlots(layer, {outputName});
1125}
1126
Kevin Mayef33cb12021-01-29 14:24:57 +00001127void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func)
telsoa01c577f2c2018-08-31 09:22:23 +01001128{
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001129 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3);
telsoa01c577f2c2018-08-31 09:22:23 +01001130 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1131
1132 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1133
1134 ActivationDescriptor desc;
Tee Jung7ff9a602019-11-01 07:04:42 +00001135 desc.m_Function = func;
telsoa01c577f2c2018-08-31 09:22:23 +01001136
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001137 if (func == ActivationFunction::BoundedReLu)
1138 {
1139 desc.m_A = node.input(2).empty() ? std::numeric_limits<float>::max() : std::stof(node.input(2));
1140 desc.m_B = node.input(1).empty() ? std::numeric_limits<float>::lowest() : std::stof(node.input(1));
1141 }
1142
telsoa01c577f2c2018-08-31 09:22:23 +01001143 IConnectableLayer* const layer = m_Network->AddActivationLayer(desc, node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001144 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001145
1146 auto outputInfo = ComputeOutputInfo({ node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1147 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1148
1149 // register the input connection slots for the layer, connections are made after all layers have been created
1150 // only the tensors for the inputs are relevant, exclude the const tensors
1151 RegisterInputSlots(layer, {node.input(0)});
1152
1153 // register the output connection slots for the layer, connections are made after all layers have been created
1154 RegisterOutputSlots(layer, {node.output(0)});
1155}
1156
Kevin Mayef33cb12021-01-29 14:24:57 +00001157void OnnxParserImpl::ParseClip(const onnx::NodeProto& node)
Finn Williams7ee5d2c2020-03-27 11:11:50 +00001158{
1159 ParseActivation(node, ActivationFunction::BoundedReLu);
1160}
1161
Kevin Mayef33cb12021-01-29 14:24:57 +00001162void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001163{
1164 ParseActivation(node, ActivationFunction::Sigmoid);
1165}
1166
Kevin Mayef33cb12021-01-29 14:24:57 +00001167void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001168{
1169 ParseActivation(node, ActivationFunction::TanH);
1170}
1171
Kevin Mayef33cb12021-01-29 14:24:57 +00001172void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001173{
1174 ParseActivation(node, ActivationFunction::ReLu);
1175}
1176
Kevin Mayef33cb12021-01-29 14:24:57 +00001177void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node)
Tee Jung7ff9a602019-11-01 07:04:42 +00001178{
1179 ParseActivation(node, ActivationFunction::LeakyReLu);
1180}
telsoa01c577f2c2018-08-31 09:22:23 +01001181
Kevin Mayef33cb12021-01-29 14:24:57 +00001182void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001183{
Ryan OSheaed27ee72020-04-22 16:37:29 +01001184 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
1185 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
telsoa01c577f2c2018-08-31 09:22:23 +01001186
Ryan OSheaed27ee72020-04-22 16:37:29 +01001187 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
telsoa01c577f2c2018-08-31 09:22:23 +01001188
Ryan OSheaed27ee72020-04-22 16:37:29 +01001189 // TODO: unify broadcast validation code across layers
1190 // tracked by: IVGCVSW-1576
telsoa01c577f2c2018-08-31 09:22:23 +01001191
Ryan OSheaed27ee72020-04-22 16:37:29 +01001192 // Checking broadcast compatibility : only scalar or 1D tensors
1193 auto inputs = AddPrepareBroadcast(node.input(0), node.input(1));
1194 auto input0 = *m_TensorsInfo[inputs.first].m_info;
1195 auto input1 = *m_TensorsInfo[inputs.second].m_info;
1196 ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
1197
1198 unsigned int numDims = input0.GetNumDimensions();
1199 for (unsigned int i = 0; i < numDims; i++)
telsoa01c577f2c2018-08-31 09:22:23 +01001200 {
Ryan OSheaed27ee72020-04-22 16:37:29 +01001201 unsigned int dim0 = input0.GetShape()[i];
1202 unsigned int dim1 = input1.GetShape()[i];
1203 if (dim0 != dim1 && dim0 != 1 && dim1 != 1)
telsoa01c577f2c2018-08-31 09:22:23 +01001204 {
James Ward58dec6b2020-09-11 17:32:44 +01001205 throw ParseException(
1206 fmt::format("Broadcast is only supported for scalar or 1D tensors in Add node '{}'. "
1207 "Input dimensions should either match or one should be of size 1 and here, "
1208 "{} and {} {}",
1209 node.name(),
1210 TensorInfoAsString(*m_TensorsInfo[inputs.first].m_info, inputs.first,
1211 m_TensorsInfo[inputs.first].m_dtype),
1212 TensorInfoAsString(*m_TensorsInfo[inputs.second].m_info, inputs.second,
1213 m_TensorsInfo[inputs.second].m_dtype),
1214 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001215 }
telsoa01c577f2c2018-08-31 09:22:23 +01001216 }
Ryan OSheaed27ee72020-04-22 16:37:29 +01001217
1218
1219 IConnectableLayer* layer = m_Network->AddAdditionLayer(node.name().c_str());
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01001220 ARMNN_ASSERT(layer != nullptr);
telsoa01c577f2c2018-08-31 09:22:23 +01001221
1222 auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
Ryan OSheaed27ee72020-04-22 16:37:29 +01001223 { m_TensorsInfo[inputs.first].m_info->GetShape(),
1224 m_TensorsInfo[inputs.second].m_info->GetShape() });
telsoa01c577f2c2018-08-31 09:22:23 +01001225 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1226
Ryan OSheaed27ee72020-04-22 16:37:29 +01001227 // register the input connection -> for constant inputs, we need to make a newDim constant layer
1228 if(m_TensorsInfo[inputs.first].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001229 CreateConstantLayer(inputs.first, fmt::format("Add:constant_of_{}", node.input(0)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001230 }
1231 if(m_TensorsInfo[inputs.second].isConstant()) {
James Ward58dec6b2020-09-11 17:32:44 +01001232 CreateConstantLayer(inputs.second, fmt::format("Add:constant_of_{}", node.input(1)));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001233 }
1234 RegisterInputSlots(layer, {inputs.first, inputs.second});
telsoa01c577f2c2018-08-31 09:22:23 +01001235
Ryan OSheaed27ee72020-04-22 16:37:29 +01001236 // register the output connection
telsoa01c577f2c2018-08-31 09:22:23 +01001237 RegisterOutputSlots(layer, {node.output(0)});
1238}
1239
Kevin Mayef33cb12021-01-29 14:24:57 +00001240void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001241{
1242 Pooling2dDescriptor desc;
1243 desc.m_PoolType = PoolingAlgorithm::Average;
1244
1245 uint32_t count_include_pad = 0;
1246 count_include_pad = ReadOptionalNodeUint32Attribute(node, "count_include_pad");
1247 if(count_include_pad) {
1248 desc.m_PaddingMethod = PaddingMethod::IgnoreValue;
1249 }
1250 AddPoolingLayer(node, desc);
1251}
1252
Kevin Mayef33cb12021-01-29 14:24:57 +00001253void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001254{
1255 //IGNORE momentum parameter and spatial parameters
1256
1257 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 5);
1258 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1259
1260 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1261 for(int ind = 1; ind < node.input_size(); ++ind)
1262 {
1263 auto tensor = node.input(ind);
1264 if(! m_TensorsInfo[tensor].isConstant())
1265 {
James Ward58dec6b2020-09-11 17:32:44 +01001266 throw ParseException(
1267 fmt::format("Input tensor '{}' should be constant in BatchNormalization node '{}' {}",
1268 tensor,
1269 node.name(),
1270 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001271 }
1272 }
1273
1274 float epsilon = ReadOptionalNodeFloatAttribute(node, "epsilon", 1e-5f);
1275 BatchNormalizationDescriptor desc;
1276 desc.m_Eps = epsilon;
1277
1278 auto scaleTensor = CreateConstTensor(node.input(1));
1279 auto biasTensor = CreateConstTensor(node.input(2));
1280 auto meanTensor = CreateConstTensor(node.input(3));
1281 auto varTensor = CreateConstTensor(node.input(4));
1282
1283 IConnectableLayer* layer = m_Network->AddBatchNormalizationLayer(desc,
1284 meanTensor.first,
1285 varTensor.first,
1286 biasTensor.first,
1287 scaleTensor.first,
1288 node.name().c_str());
1289 ARMNN_ASSERT(layer != nullptr);
1290
1291 auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {m_TensorsInfo[node.input(0)].m_info->GetShape()});
1292 layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
1293
1294 RegisterInputSlots(layer, {node.input(0)}); //don't register constant inputs
1295
1296 // register the output connection
1297 RegisterOutputSlots(layer, {node.output(0)});
1298}
1299
Kevin Mayef33cb12021-01-29 14:24:57 +00001300void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node)
Ryan OSheaed27ee72020-04-22 16:37:29 +01001301{
1302 CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1);
1303 if (!node.attribute(0).has_t())
1304 {
James Ward58dec6b2020-09-11 17:32:44 +01001305 throw ParseException(fmt::format("Value not found for Constant node '{}' {}",
1306 node.name(),
1307 CHECK_LOCATION().AsString()));
Ryan OSheaed27ee72020-04-22 16:37:29 +01001308 }
1309 const onnx::TensorProto& onnxTensor = node.attribute(0).t();
1310
1311 //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
1312 CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(),
1313 static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
1314
1315 //Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
1316 m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
1317 m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
1318 m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
1319
1320 CreateConstantLayer(node.output(0), node.name());
1321}
1322
Kevin Mayef33cb12021-01-29 14:24:57 +00001323void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
telsoa01c577f2c2018-08-31 09:22:23 +01001324{
1325 CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); //input, weight, (bias)
1326 CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
1327
1328 VALID_INPUTS(node, STR_LIST(onnx::TensorProto::FLOAT));
1329
1330 if(m_TensorsInfo[node.input(0)].m_info->GetNumDimensions() != 4)
1331 {
James Ward58dec6b2020-09-11 17:32:44 +01001332 throw ParseException(
1333 fmt::format("ArmNN only supports 2D convolution and Conv layer '{}' input {} {}",
1334 node.name(),
1335 TensorInfoAsString(*m_TensorsInfo[node.input(0)].m_info, node.input(0),
1336 m_TensorsInfo[node.input(0)].m_dtype),
1337 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001338 }
1339
1340 if(!m_TensorsInfo[node.input(1)].isConstant())
1341 {
James Ward58dec6b2020-09-11 17:32:44 +01001342 throw ParseException(
1343 fmt::format("Weights '{}' should be constant in Conv layer '{}' {}",
1344 node.input(1),
1345 node.name(),
1346 CHECK_LOCATION().AsString()));
telsoa01c577f2c2018-08-31 09:22:23 +01001347 }
1348
1349 auto inputInfo = *m_TensorsInfo[node.input(0)].m_info;
1350
telsoa01c577f2c2018-08-31 09:22:23 +01001351 Convolution2dDescriptor desc;
1352 desc.m_BiasEnabled = false;
1353
1354 std::vector<uint32_t> strides = ReadOptionalNodeUint32ListAttribute(node, "strides");
1355 if(strides.empty())
1356 {
1357 desc.m_StrideX = 1;
1358 desc.m_StrideY = 1;
1359 }
1360 else
1361 {
1362 desc.m_StrideX = strides[1];
1363 desc.m_StrideY = strides[0];
1364 }
1365
Sadik Armagan60bb9d82021-01-11 15:15:01 +00001366 std::vector<uint32_t> dilations = ReadOptionalNodeUint32ListAttribute(node, "dilations");
1367 if(!dilations.empty())
1368 {
1369 desc.m_DilationX = dilations[1];
1370 desc.m_DilationY = dilations[0];
1371 }
1372
    std::vector<uint32_t> pads = ReadOptionalNodeUint32ListAttribute(node, "pads");
    //Check the explicit padding attribute first
    if(pads.empty())
    {
        //Fall back to the deprecated auto_pad attribute
        std::string paddingString = ReadOptionalNodeStringAttribute(node, "auto_pad");
        if(paddingString != "VALID" && paddingString != "" && paddingString != "NOTSET")
        {
            bool isUpper;
            if (paddingString == "SAME_LOWER")
            {
                isUpper = false;
            }
            else if (paddingString == "SAME_UPPER")
            {
                isUpper = true;
            }
            else
            {
                throw ParseException(
                    fmt::format("Invalid auto_pad attribute for node {}. Only SAME_UPPER, SAME_LOWER and VALID "
                                "are supported; found {} {}",
                                node.name(),
                                paddingString,
                                CHECK_LOCATION().AsString()));
            }
            uint32_t inputHeight = inputInfo.GetShape()[2];
            uint32_t inputWidth = inputInfo.GetShape()[3];

            uint32_t weightHeight;
            uint32_t weightWidth;
            std::vector<uint32_t> kernel_shape = ReadOptionalNodeUint32ListAttribute(node, "kernel_shape");
            if (kernel_shape.empty())
            {
                const TensorInfo weightTensorInfo = *m_TensorsInfo[node.input(1)].m_info;
                weightHeight = weightTensorInfo.GetShape()[2];
                weightWidth = weightTensorInfo.GetShape()[3];
            }
            else
            {
                weightHeight = kernel_shape[0];
                weightWidth = kernel_shape[1];
            }
            CalcPadding(inputHeight,
                        weightHeight,
                        desc.m_StrideY,
                        desc.m_DilationY,
                        &desc.m_PadTop,
                        &desc.m_PadBottom,
                        isUpper);
            CalcPadding(inputWidth,
                        weightWidth,
                        desc.m_StrideX,
                        desc.m_DilationX,
                        &desc.m_PadLeft,
                        &desc.m_PadRight,
                        isUpper);
        }
    }
    else
    {
        desc.m_PadTop = pads[0];
        desc.m_PadLeft = pads[1];
        desc.m_PadBottom = pads[2];
        desc.m_PadRight = pads[3];
    }

    uint32_t group = ReadOptionalNodeUint32Attribute(node, "group", 1);
    if(group > 1)
    {
        if (group > inputInfo.GetShape()[1])
        {
            throw ParseException(
                fmt::format("Error parsing Convolution node: {}. "
                            "The 'group'={} parameter cannot be larger than the "
                            "channel dimension of the input shape={} (in NCHW format). {}",
                            node.name(),
                            group,
                            inputInfo.GetShape()[1],
                            CHECK_LOCATION().AsString()));
        }
        else if (group == inputInfo.GetShape()[1])
        {
            // use a depthwise convolution, because the number of groups equals
            // the number of input channels
            AddConvLayerWithDepthwiseConv(node, desc);
            return;
        }
        else
        {
            // TODO: split the input by channels into channels/groups separate convolutions
            //       and concatenate the results afterwards
            throw ParseException(fmt::format("Error parsing Convolution node: {}. "
                                             "The 'group'={} parameter should be 1 or equal to the "
                                             "channel dimension of the input shape={} (in NCHW format). {}",
                                             node.name(),
                                             group,
                                             inputInfo.GetShape()[1],
                                             CHECK_LOCATION().AsString()));
        }
    }

    armnn::IConnectableLayer* layer;
    auto weightTensor = CreateConstTensor(node.input(1));

    if (node.input_size() == 3)
    {
        if(!m_TensorsInfo[node.input(2)].isConstant())
        {
            throw ParseException(fmt::format("Bias '{}' should be constant in Conv layer '{}' {}",
                                             node.input(2),
                                             node.name(),
                                             CHECK_LOCATION().AsString()));
        }
        desc.m_BiasEnabled = true;
        auto biasTensor = CreateConstTensor(node.input(2));
        layer = m_Network->AddConvolution2dLayer(desc,
                                                 weightTensor.first,
                                                 Optional<ConstTensor>(biasTensor.first),
                                                 node.name().c_str());
    }
    else
    {
        layer = m_Network->AddConvolution2dLayer(desc,
                                                 weightTensor.first,
                                                 EmptyOptional(),
                                                 node.name().c_str());
    }
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
                                        { m_TensorsInfo[node.input(0)].m_info->GetShape(),
                                          m_TensorsInfo[node.input(1)].m_info->GetShape() });
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT);

    int64_t axis = ReadOptionalNodeInt64Attribute(node, "axis", 1);
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

    /// Negative axis conversion
    if (axis < 0)
    {
        axis += inputShape.GetNumDimensions();
    }

    /// Check axis is within the input dimensions
    if (axis < 0 || axis >= static_cast<int64_t>(inputShape.GetNumDimensions()))
    {
        throw ParseException(fmt::format("Axis '{}' invalid. Tensor has '{}' dimensions in FlattenLayer '{}'",
                                         axis, inputShape.GetNumDimensions(), node.name()));
    }

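    /// Example: an input of shape (2, 3, 4, 5) with axis=2 flattens to (2*3, 4*5) = (6, 20).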
    /// With axis 0 the first output dimension is always 1; both dimensions default to 1 because 0 is invalid
    uint dimension1{1};
    uint dimension2{1};
    uint i{0};

    /// dimension1 = (d_0 * d_1 ... d_(axis-1))
    for (i = 0; i < static_cast<uint>(axis); i++)
    {
        dimension1 *= inputShape[i];
    }

    /// dimension2 = (d_axis * d_(axis+1) ... d_n)
    for (i = static_cast<uint>(axis); i < inputShape.GetNumDimensions(); i++)
    {
        dimension2 *= inputShape[i];
    }

    TensorShape outputShape{dimension1, dimension2};

    auto outInfo = ComputeReshapeInfo(outputShape, inputShape, node.output(0));
    m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
    CreateReshapeLayer(node.input(0), node.output(0), node.name());
}

void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
{
    Pooling2dDescriptor desc = Pooling2dDescriptor();
    desc.m_PoolType = PoolingAlgorithm::Average;

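    // GlobalAveragePool averages each channel over its whole spatial extent, so a pooling
    // window covering the full H x W plane turns an (N, C, H, W) input into (N, C, 1, 1).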
    //kernel size is the same as input
    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
    desc.m_PoolWidth = inputShape[3];
    desc.m_PoolHeight = inputShape[2];

    IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, node.name().c_str());
    ARMNN_ASSERT(layer != nullptr);

    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);

    // register the input connection slots for the layer, connections are made after all layers have been created
    // only the tensors for the inputs are relevant, exclude the const tensors
    RegisterInputSlots(layer, {node.input(0)});

    // register the output connection slots for the layer, connections are made after all layers have been created
    RegisterOutputSlots(layer, {node.output(0)});
}

void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node)
{
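    // ONNX MaxPool ignores padded elements when taking the maximum, which is what
    // ArmNN's PaddingMethod::Exclude expresses here.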
    Pooling2dDescriptor desc;
    desc.m_PoolType = PoolingAlgorithm::Max;
    desc.m_PaddingMethod = PaddingMethod::Exclude;
    AddPoolingLayer(node, desc);
}

void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
{
    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);

    CHECK_VALID_DATATYPE(node.name(), node.input(0),
                         m_TensorsInfo[node.input(0)].m_dtype,
                         onnx::TensorProto::FLOAT); //input
    CHECK_VALID_DATATYPE(node.name(), node.input(1),
                         m_TensorsInfo[node.input(1)].m_dtype,
                         onnx::TensorProto::INT64); //shape

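    // ArmNN needs the target shape when the network is constructed, so the ONNX shape
    // input must be a constant (an initializer) rather than a runtime-computed tensor.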
    if(!m_TensorsInfo[node.input(1)].isConstant())
    {
        throw ParseException(fmt::format("Shape '{}' should be constant in Reshape layer '{}' {}",
                                         node.input(1),
                                         node.name(),
                                         CHECK_LOCATION().AsString()));
    }

    if(m_TensorsInfo[node.input(0)].isConstant())
    {
        //make a new constant tensor -> move the data to the output tensor (the shape is already correct in the output tensor)
        if(m_TensorsInfo.count(node.output(0)) == 0)
        {
            m_TensorsInfo[node.output(0)] = OnnxTensor();
        }
        m_TensorsInfo[node.output(0)].m_tensor =
            std::make_unique<onnx::TensorProto>(*m_TensorsInfo[node.input(0)].m_tensor);
    }
    else
    {
        TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();

        if(m_TensorsInfo.count(node.output(0)) == 0 || m_TensorsInfo[node.output(0)].m_info == nullptr)
        {
            uint64_t dims = static_cast<uint64_t>(m_TensorsInfo[node.input(1)].m_tensor->int64_data_size());

            // Collect the target dimensions first: brace-initialising a TensorShape with
            // {dims, 1} would select the initializer-list constructor and create a fixed
            // two-dimensional shape instead of one with 'dims' dimensions.
            std::vector<unsigned int> targetDims;
            for(uint i = 0; i < dims; ++i)
            {
                int val = CHECKED_INT32(m_TensorsInfo[node.input(1)].m_tensor->int64_data(static_cast<int>(i)));
                targetDims.push_back(static_cast<unsigned int>(val));
            }
            TensorShape targetShape(static_cast<unsigned int>(targetDims.size()), targetDims.data());

            auto outInfo = ComputeReshapeInfo(targetShape, inputShape, node.output(0));
            m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(outInfo);
        }

        CreateReshapeLayer(node.input(0), node.output(0), node.name());
    }
}

void OnnxParserImpl::PrependForBroadcast(const std::string& outputName,
                                         const std::string& input0,
                                         const std::string& input1)
{
    //input0 should be reshaped to have the same number of dimensions as input1
    TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info);

    TensorShape input0Shape = m_TensorsInfo[input0].m_info->GetShape();
    TensorShape input1Shape = m_TensorsInfo[input1].m_info->GetShape();

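    // Example: broadcasting input0 of shape (3, 4) against input1 of shape (2, 5, 3, 4)
    // prepends ones to give (1, 1, 3, 4). This assumes input1 has at least as many
    // dimensions as input0; otherwise the unsigned subtraction below would underflow.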
    uint32_t diff = input1Shape.GetNumDimensions() - input0Shape.GetNumDimensions();
    std::vector<uint32_t> newShape;
    while(diff > 0)
    {
        newShape.push_back(1);
        diff--;
    }
    for (uint dim = 0; dim < input0Shape.GetNumDimensions(); ++dim)
    {
        newShape.push_back(input0Shape[dim]);
    }
    outputTensorInfo.SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data()));

    //add the new tensor to m_TensorsInfo
    m_TensorsInfo[outputName] = OnnxTensor();
    m_TensorsInfo[outputName].m_info = std::make_unique<TensorInfo>(outputTensorInfo);

    //add a reshape layer if the parent was not constant...
    if(!m_TensorsInfo[input0].isConstant())
    {
        CreateReshapeLayer(input0, outputName, fmt::format("Add:reshapeOf{}", input0));
    }
    else //make it constant and it will be created in Add
    {
        m_TensorsInfo[outputName].m_tensor = std::make_unique<onnx::TensorProto>(*m_TensorsInfo[input0].m_tensor);
    }
}

void OnnxParserImpl::SetupInputLayers()
{
    //Find the user inputs and add their layers
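    // Note: ONNX models may also list initializers in graph.input, so entries that are
    // known constants are skipped; only genuine user inputs become ArmNN input layers.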
    for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
    {
        auto input = m_Graph->input(inputIndex);
        if (!m_TensorsInfo[input.name()].isConstant())
        {
            IConnectableLayer* layer =
                m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
            auto tensorInfo = ToTensorInfo(input);
            layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

            RegisterOutputSlots(layer, { input.name() });
        }
    }
}

void OnnxParserImpl::SetupOutputLayers()
{
    if(m_Graph->output_size() == 0)
    {
        throw ParseException(fmt::format("The given model does not have any outputs {}",
                                         CHECK_LOCATION().AsString()));
    }

    for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
    {
        IConnectableLayer* layer =
            m_Network->AddOutputLayer(static_cast<armnn::LayerBindingId>(outputIndex),
                                      m_Graph->output(outputIndex).name().c_str());

        RegisterInputSlots(layer, { m_Graph->output(outputIndex).name() });
    }
}

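// Connections are deferred: while nodes are parsed, m_TensorConnections records, per tensor
// name, the single producing output slot and every consuming input slot. Once all layers
// exist, each producer is wired to its consumers in a single pass.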
void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumInputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor inputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumInputSlots(),
                        CHECK_LOCATION().AsString()));
    }
    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumInputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IInputSlot* slot = &(layer->GetInputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            //First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }
        m_TensorConnections[tensorId].inputSlots.push_back(slot);
    }
}

void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds)
{
    ARMNN_ASSERT(layer != nullptr);
    if (tensorIds.size() != layer->GetNumOutputSlots())
    {
        throw ParseException(
            fmt::format("The number of tensor outputs ({}) does not match the number expected ({}) {}",
                        tensorIds.size(),
                        layer->GetNumOutputSlots(),
                        CHECK_LOCATION().AsString()));
    }

    for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
    {
        std::string tensorId = tensorIds[slotIndex];
        armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));

        auto it = m_TensorConnections.find(tensorId);

        if (it == m_TensorConnections.end())
        {
            //First time seeing this tensor, we need to map it
            m_TensorConnections[tensorId] = TensorSlots();
        }

        TensorSlots& tensorSlots = m_TensorConnections[tensorId];

        // assuming there is only one producer for that tensor
        if (tensorSlots.outputSlot != nullptr)
        {
            throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
                                             "tensor:{} {}",
                                             tensorId,
                                             CHECK_LOCATION().AsString()));
        }
        tensorSlots.outputSlot = slot;
    }
}

BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->input_size(); ++i)
    {
        auto input = m_Graph->input(i);
        if(input.name() == name)
        {
            return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
        }
    }
    throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const
{
    for(int i = 0; i < m_Graph->output_size(); ++i)
    {
        auto output = m_Graph->output(i);
        if(output.name() == name)
        {
            return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
        }
    }
    throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",
                                               name, CHECK_LOCATION().AsString()));
}

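// Usage sketch (illustrative only; "data" and "prob" stand in for tensor names from the
// loaded model, and 'parser' for a parser that has already created the network):
//   BindingPointInfo inputBinding  = parser.GetNetworkInputBindingInfo("data");
//   BindingPointInfo outputBinding = parser.GetNetworkOutputBindingInfo("prob");
//   // .first is the armnn::LayerBindingId, .second the armnn::TensorInfo for the buffer.
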
std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model)
{
    if(model == nullptr)
    {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> inputNames;
    std::map<std::string, bool> isConstant;
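    // Record every initializer so constant tensors listed in graph.input can be excluded.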
    for(auto tensor : model->graph().initializer())
    {
        isConstant[tensor.name()] = true;
    }
    for(auto input : model->graph().input())
    {
        auto it = isConstant.find(input.name());
        if(it == isConstant.end())
        {
            inputNames.push_back(input.name());
        }
    }
    return inputNames;
}

std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model)
{
    if(model == nullptr)
    {
        throw InvalidArgumentException(fmt::format("The given model cannot be null {}",
                                                   CHECK_LOCATION().AsString()));
    }

    std::vector<std::string> outputNames;
    for(auto output : model->graph().output())
    {
        outputNames.push_back(output.name());
    }
    return outputNames;
}

const std::string OnnxParserImpl::GetVersion()
{
    return ONNX_PARSER_VERSION;
}

} // namespace armnnOnnxParser