IVGCVSW-6382 Add Shape operator support to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3547effcbebf1ebc02d3b20f5db394a26991424d
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 49f0271..889c35f 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -426,7 +426,8 @@
{ "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu },
{ "Conv", &OnnxParserImpl::ParseConv },
{ "Add", &OnnxParserImpl::ParseAdd },
- { "Flatten", &OnnxParserImpl::ParseFlatten},
+ { "Flatten", &OnnxParserImpl::ParseFlatten },
+ { "Shape", &OnnxParserImpl::ParseShape }
};
template<typename TypePair, typename Location>
@@ -1653,6 +1654,30 @@
AddPoolingLayer(node, desc);
}
+// Parses an ONNX Shape node into an Arm NN Shape layer.
+//
+// The node must have exactly one input (the tensor whose shape is reported) and exactly one output.
+// ONNX defines the output of Shape as INT64; when the model declares type information for the
+// output tensor, that declared type is validated against INT64.
+// The input and output tensor names are registered so the actual connections can be made once
+// every layer of the graph has been created.
+void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    // Output must be INT64. Only validate a type the model actually declares: indexing with
+    // operator[] would default-insert an entry for an undeclared output and then validate that
+    // entry's default dtype instead of anything the model specified.
+    // An entry without TensorInfo is treated as undeclared.
+    // NOTE(review): confirm m_dtype is only ever populated alongside m_info.
+    const auto declaredOutput = m_TensorsInfo.find(node.output(0));
+    if (declaredOutput != m_TensorsInfo.end() && declaredOutput->second.m_info != nullptr)
+    {
+        CHECK_VALID_DATATYPE(node.name(), node.output(0),
+                             declaredOutput->second.m_dtype,
+                             onnx::TensorProto::INT64);
+    }
+
+    IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    // Presumes the input's TensorInfo is already known (declared in the model or set up by the
+    // producing node) -- m_info is dereferenced unchecked here, as in the original.
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+    // NOTE(review): when the output was not declared in the model, confirm the TensorInfo produced
+    // by ComputeOutputInfo carries an integer data type, as expected for a Shape layer's output.
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    // register the input connection slots for the layer, connections are made after all layers have been created
+    RegisterInputSlots(layer, {node.input(0)});
+
+    // register the output connection slots for the layer, connections are made after all layers have been created
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
{
CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);