IVGCVSW-6382 Add Shape operator support to ONNX parser

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I3547effcbebf1ebc02d3b20f5db394a26991424d
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 49f0271..889c35f 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -426,7 +426,8 @@
     { "LeakyRelu",             &OnnxParserImpl::ParseLeakyRelu },
     { "Conv",                  &OnnxParserImpl::ParseConv },
     { "Add",                   &OnnxParserImpl::ParseAdd },
-    { "Flatten",               &OnnxParserImpl::ParseFlatten},
+    { "Flatten",               &OnnxParserImpl::ParseFlatten },
+    { "Shape",                 &OnnxParserImpl::ParseShape },
 };
 
 template<typename TypePair, typename Location>
@@ -1653,6 +1654,46 @@
     AddPoolingLayer(node, desc);
 }
 
+// Parses an ONNX Shape node into an ArmNN Shape layer. Per the ONNX operator spec the node
+// takes exactly one input tensor and produces a 1-D INT64 tensor holding that input's dimensions.
+void OnnxParserImpl::ParseShape(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    // Shape-15 added optional 'start'/'end' attributes that return only a slice of the shape.
+    // Slicing is not implemented, so reject nodes that set either attribute rather than
+    // silently producing the full shape.
+    for (const auto& attribute : node.attribute())
+    {
+        if (attribute.name() == "start" || attribute.name() == "end")
+        {
+            throw ParseException("Shape node '" + node.name() + "' uses the unsupported '" +
+                                 attribute.name() + "' attribute " + CHECK_LOCATION().AsString());
+        }
+    }
+
+    // ONNX defines the output of Shape as INT64; reject models that declare any other type.
+    // NOTE(review): if the output's type is not declared in the model, this reads a default m_dtype - confirm.
+    CHECK_VALID_DATATYPE(node.name(), node.output(0),
+                         m_TensorsInfo[node.output(0)].m_dtype,
+                         onnx::TensorProto::INT64);
+
+    IConnectableLayer* layer = m_Network->AddShapeLayer(node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
+    // NOTE(review): m_info is dereferenced without a null check; assumes the input's TensorInfo is known here.
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, {inputShape});
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    // register the input connection slots for the layer, connections are made after all layers have been created
+    RegisterInputSlots(layer, {node.input(0)});
+
+    // register the output connection slots for the layer, connections are made after all layers have been created
+    RegisterOutputSlots(layer, {node.output(0)});
+}
+
 void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node)
 {
     CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);