IVGCVSW-6382 Add Gather operator support to ONNX parser

 * Add ParseGather to support the Gather operator in the ONNX parser (see the shape-rule sketch below)
 * Add support for int64 constant tensors by converting them to int32
 * Add OnnxParserTestUtils
 * Refactor ValidateTensorShapesFromInputs of GatherLayer
 * Unit tests
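
For reference, ONNX Gather selects entries of the data tensor along a single
axis using the indices tensor; the output shape is the data shape with the
axis dimension replaced by the indices shape. A minimal standalone sketch of
that shape rule (the helper below is illustrative only, not part of this
patch):

    #include <cstddef>
    #include <vector>

    // out.shape = data.shape[0:axis] ++ indices.shape ++ data.shape[axis+1:]
    std::vector<unsigned int> GatherOutputShape(const std::vector<unsigned int>& dataShape,
                                                const std::vector<unsigned int>& indicesShape,
                                                std::size_t axis)
    {
        std::vector<unsigned int> outShape(dataShape.begin(), dataShape.begin() + axis);
        outShape.insert(outShape.end(), indicesShape.begin(), indicesShape.end());
        outShape.insert(outShape.end(), dataShape.begin() + axis + 1, dataShape.end());
        return outShape;
    }

    // e.g. data shape [3,4], indices shape [2], axis 0 -> output shape [2,4]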

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ie9dff640240e14a062fef38f7faf0ccc212de5f7
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 889c35f..e70eb64 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -427,7 +427,8 @@
     { "Conv",                  &OnnxParserImpl::ParseConv },
     { "Add",                   &OnnxParserImpl::ParseAdd },
     { "Flatten",               &OnnxParserImpl::ParseFlatten },
-    { "Shape",                 &OnnxParserImpl::ParseShape }
+    { "Shape",                 &OnnxParserImpl::ParseShape },
+    { "Gather",                &OnnxParserImpl::ParseGather },
 };
 
 template<typename TypePair, typename Location>
@@ -533,6 +534,10 @@
     TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
     onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
 
+    // ONNX can have Float16 and double constant nodes, but ArmNN only supports float32
+    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
+                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
+
     // Makes sure IsConstant flag is set.
     tensorInfo.SetConstant();
 
@@ -568,6 +573,65 @@
     }
 }
 
+std::pair<ConstTensor, std::unique_ptr<int32_t[]>>
+OnnxParserImpl::CreateInt64ConstTensor(const std::string name,
+                                       armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+    TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
+    onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
+
+    CHECK_VALID_DATATYPE(name, onnxTensor.name(),
+                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::INT64);
+
+    // Makes sure IsConstant flag is set.
+    tensorInfo.SetConstant();
+    uint numElements = tensorInfo.GetNumElements();
+
+    // Const tensors require at least a list of values
+    if (numElements == 0)
+    {
+        throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
+                                         name,
+                                         CHECK_LOCATION().AsString()));
+    }
+
+    // Copy the value list entries into the destination
+    if (!onnxTensor.has_raw_data())
+    {
+        auto srcData = onnxTensor.int64_data().data();
+        if(numElements != static_cast<uint>(onnxTensor.int64_data_size()))
+        {
+            throw ParseException(
+                fmt::format("The number of data elements provided ({}) does not match the tensor '{}' number of "
+                            "elements ({}) {}",
+                            onnxTensor.int64_data_size(),
+                            name,
+                            tensorInfo.GetNumElements(),
+                            CHECK_LOCATION().AsString()));
+        }
+
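+        // Narrow each int64 value to int32; CHECKED_INT32 guards against out-of-range values.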
+        std::vector<int32_t> int32Data;
+        for(uint i = 0; i < numElements; i++)
+        {
+            int32_t int32Value = CHECKED_INT32(srcData[i]);
+            int32Data.push_back(int32Value);
+        }
+
+        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
+    }
+    else
+    {
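+        // raw_data stores the tensor payload as little-endian bytes; reinterpret it as int64 values.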
+        auto srcData = reinterpret_cast<const int64_t*>(onnxTensor.raw_data().c_str());
+        std::vector<int32_t> int32Data;
+        for(uint i = 0; i < numElements; i++)
+        {
+            int32_t int32Value = CHECKED_INT32(srcData[i]);
+            int32Data.push_back(int32Value);
+        }
+        return CreateConstTensorImpl<int32_t>(int32Data.data(), tensorInfo, permutationVector);
+    }
+}
+
 ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
 {
     FILE* fd = fopen(graphFile, "r");
@@ -1152,7 +1216,14 @@
 void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName)
 {
     auto armnnTensor = CreateConstTensor(tensorName);
+    IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
+    layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
+    RegisterOutputSlots(layer, {tensorName});
+}
 
+void OnnxParserImpl::CreateInt64ConstantLayer(const std::string& tensorName, const std::string& layerName)
+{
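+    // As CreateConstantLayer, but the int64 initialiser data is narrowed to int32 first.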
+    auto armnnTensor = CreateInt64ConstTensor(tensorName);
     IConnectableLayer* layer = m_Network->AddConstantLayer(armnnTensor.first, layerName.c_str());
     layer->GetOutputSlot(0).SetTensorInfo(armnnTensor.first.GetInfo());
     RegisterOutputSlots(layer, {tensorName});
@@ -1370,16 +1441,25 @@
     }
     const onnx::TensorProto& onnxTensor = node.attribute(0).t();
 
-    //ONNX can have Float16 and double constant nodes but ArmNN only supports float32
-    CHECK_VALID_DATATYPE(node.name(), onnxTensor.name(),
-                         static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type()), onnx::TensorProto::FLOAT);
-
     //Register this as a m_ConstParam so we know we can use it as a constant param in future layers.
     m_TensorsInfo[node.output(0)].m_tensor = std::make_unique<const onnx::TensorProto>(onnxTensor);
     m_TensorsInfo[node.output(0)].m_info = std::make_unique<TensorInfo>(ToTensorInfo(onnxTensor));
     m_TensorsInfo[node.output(0)].m_dtype = static_cast<onnx::TensorProto::DataType>(onnxTensor.data_type());
 
-    CreateConstantLayer(node.output(0), node.name());
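+    // Only float32 and int64 (narrowed to int32) Constant nodes are supported.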
+    if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_FLOAT)
+    {
+        CreateConstantLayer(node.output(0), node.name());
+    }
+    else if (m_TensorsInfo[node.output(0)].m_dtype == onnx::TensorProto_DataType_INT64)
+    {
+        CreateInt64ConstantLayer(node.output(0), node.name());
+    }
+    else
+    {
+        throw ParseException(fmt::format("Data type not supported for Constant node '{}' {}",
+                                         node.name(),
+                                         CHECK_LOCATION().AsString()));
+    }
 }
 
 void OnnxParserImpl::ParseConv(const onnx::NodeProto& node)
@@ -1622,6 +1702,29 @@
     CreateReshapeLayer(node.input(0), node.output(0), node.name());
 }
 
+void OnnxParserImpl::ParseGather(const onnx::NodeProto& node)
+{
+    CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2);
+    CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1);
+
+    armnn::GatherDescriptor gatherDescriptor;
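+    // ONNX Gather defaults the axis attribute to 0 when it is not present.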
+    gatherDescriptor.m_Axis = static_cast<int>(ReadOptionalNodeInt64Attribute(node, "axis", 0));
+
+    IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, node.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
+
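+    // Infer the Gather output shape from the data and indices shapes.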
+    TensorShape inputShape = m_TensorsInfo[node.input(0)].m_info->GetShape();
+    TensorShape indicesShape = m_TensorsInfo[node.input(1)].m_info->GetShape();
+    auto outputInfo = ComputeOutputInfo({node.output(0)}, layer, { inputShape, indicesShape });
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    // Register the input connection slots for the layer; connections are made after all layers have been created.
+    RegisterInputSlots(layer, { node.input(0), node.input(1) });
+
+    // Register the output connection slots for the layer; connections are made after all layers have been created.
+    RegisterOutputSlots(layer, { node.output(0) });
+}
+
 void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node)
 {
     Pooling2dDescriptor desc = Pooling2dDescriptor();