GitHub #714: OnnxParser FullyConnectedLayer inferred shape doesn't match

 * Added reshape layers before and after FullyConnected to support
   dimensions > 2, as illustrated in the example below. This is now
   consistent with the Delegate and TfLiteParser.
 * Refactored AddFullyConnected method to remove duplicate code.
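
   For example (illustrative shapes only, not taken from this patch), a
   MatMul with a [2, 3, 4] input and [4, 5] weights is now lowered as:

       Reshape        [2, 3, 4] -> [6, 4]
       FullyConnected [6, 4] x [4, 5] -> [6, 5]
       Reshape        [6, 5] -> [2, 3, 5]

   so the parser's inferred output shape matches the output shape declared
   in the ONNX model.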

Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Signed-off-by: Francis Murtagh <francis.murtagh@arm.com>
Change-Id: I04dfeb38dbcac096c5fcd9dcb5e3821d38ce6550
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 552d4e4..936216f 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022-2023 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #include "OnnxParser.hpp"
@@ -505,9 +505,10 @@
                                    outNames.end(),
                                    [this](std::string name)
                                    {
-                                       return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr
-                                       || m_TensorsInfo[name].m_info->GetShape().GetDimensionality() ==
-                                          Dimensionality::NotSpecified);
+                                       return (m_TensorsInfo.count(name) == 0 ||
+                                               m_TensorsInfo[name].m_info == nullptr ||
+                                               m_TensorsInfo[name].m_info->GetShape().GetDimensionality() ==
+                                               Dimensionality::NotSpecified);
                                    });
     std::vector<TensorInfo> outInfo;
     //if the output info(s) are not here, we need to compute them
@@ -1148,16 +1149,23 @@
 
 void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode)
 {
-
     // find matmul inputs
-    std::string weightName;
     std::string inputName;
+    std::string weightName;
+    std::string biasName;
+    std::string outputName;
     CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.input_size()), 2);
     CHECK_VALID_SIZE(static_cast<size_t>(matmulNode.output_size()), 1);
     VALID_INPUTS(matmulNode, STR_LIST(onnx::TensorProto::FLOAT));
 
     GetInputAndParam(matmulNode, &inputName, &weightName, CHECK_LOCATION());
 
+    TensorInfo inputInfo = *m_TensorsInfo[inputName].m_info;
+    TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
+    TensorInfo biasInfo;
+
+    std::vector<std::string> inputNames;
+
     FullyConnectedDescriptor desc;
     desc.m_BiasEnabled = addNode != nullptr;
 
@@ -1165,7 +1173,6 @@
     if(desc.m_BiasEnabled)
     {
         // find bias const
-        std::string biasName;
         CHECK_VALID_SIZE(static_cast<size_t>(addNode->input_size()), 2);
         CHECK_VALID_SIZE(static_cast<size_t>(addNode->output_size()), 1);
         VALID_INPUTS(*addNode, STR_LIST(onnx::TensorProto::FLOAT));
@@ -1174,8 +1181,7 @@
 
         //Output shape is [1, weights[1]] and 1d vec in ONNX can be [1,X] so we convert biases to "armnn" 1D
         To1DTensor(biasName, CHECK_LOCATION());
-        TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
-        TensorInfo biasInfo = *m_TensorsInfo[biasName].m_info;
+        biasInfo = *m_TensorsInfo[biasName].m_info;
 
         if (weightInfo.GetShape()[1] != biasInfo.GetShape()[0])
         {
@@ -1191,61 +1197,116 @@
                             CHECK_LOCATION().AsString()));
         }
 
-        // Just add a FullyConnected layer, weights and biases are handled as inputs now.
-        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
-        ARMNN_ASSERT(layer != nullptr);
-
-        auto outputInfo = ComputeOutputInfo({addNode->output(0)}, layer,
-                                            {m_TensorsInfo[inputName].m_info->GetShape(),
-                                             m_TensorsInfo[weightName].m_info->GetShape()});
-        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
-
-        // Add constant layer to store weights/biases and connect to FullyConnected layer..
-        if(m_TensorsInfo[weightName].isConstant())
-        {
-            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);
-
-            weightInfo.SetConstant();
-            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
-            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
-        }
-
-        if(m_TensorsInfo[biasName].isConstant())
-        {
-            IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);
-
-            biasInfo.SetConstant();
-            biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
-            biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
-        }
-
-        RegisterInputSlots(layer, {inputName, weightName, biasName});
-        RegisterOutputSlots(layer, {addNode->output(0)});
+        inputNames = { inputName, weightName, biasName };
+        outputName = addNode->output(0);
     }
     else
     {
-        layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
-        ARMNN_ASSERT(layer != nullptr);
+        inputNames = { inputName, weightName };
+        outputName = matmulNode.output(0);
+    }
 
-        auto outputInfo = ComputeOutputInfo({matmulNode.output(0)}, layer,
-                                            {m_TensorsInfo[inputName].m_info->GetShape(),
-                                             m_TensorsInfo[weightName].m_info->GetShape()});
-        layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+    // Add a FullyConnected layer; weights and biases are now handled as layer inputs.
+    layer = m_Network->AddFullyConnectedLayer(desc, matmulNode.name().c_str());
+    ARMNN_ASSERT(layer != nullptr);
 
-        // Add constant layer to store weights and connect to FullyConnected layer.
-        if(m_TensorsInfo[weightName].isConstant())
+    if (inputInfo.GetNumDimensions() > 2)
+    {
+        // Add a reshape to flatten the input to 2D [batch_size, input_size],
+        // where "input_size" is the number of inputs to the layer and matches
+        // the first dimension of the weights, and "batch_size" is the number
+        // of input elements divided by "input_size".
+        std::vector<unsigned int> reshapedDimensions(2);
+        reshapedDimensions[1] = weightInfo.GetShape()[0];
+        reshapedDimensions[0] = inputInfo.GetNumElements() / reshapedDimensions[1];
+
+        if (inputInfo.GetNumElements() % reshapedDimensions[1] != 0)
         {
-            TensorInfo weightInfo = *m_TensorsInfo[weightName].m_info;
-            IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);
-
-            weightInfo.SetConstant();
-            weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
-            weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
+            throw ParseException(
+                    fmt::format("Failed to deduce input tensor shape from filter size {} {}",
+                                reshapedDimensions[1],
+                                CHECK_LOCATION().AsString()));
         }
 
-        RegisterInputSlots(layer, {inputName, weightName});
-        RegisterOutputSlots(layer, {matmulNode.output(0)});
+        TensorInfo reshapedTensorInfo = inputInfo;
+        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+        inputInfo = reshapedTensorInfo;
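+        // Use the flattened 2D shape for the output shape inference below.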
+
+        ReshapeDescriptor reshapeDescriptor;
+        reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
+
+        std::string reshapeLayerName = fmt::format("Reshape_for:{}", layer->GetName());
+        IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(reshapeDescriptor, reshapeLayerName.c_str());
+
+        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+        RegisterInputSlots(reshapeLayer, {inputName});
+        inputNames[0] = reshapeLayerName;
     }
+
+    auto outputInfo = ComputeOutputInfo({ outputName },
+                                        layer,
+                                        { inputInfo.GetShape(),
+                                          weightInfo.GetShape() });
+    layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+    RegisterInputSlots(layer, inputNames);
+
+    // Add constant layers to store the weights/biases and connect them to the FullyConnected layer.
+    if(m_TensorsInfo[weightName].isConstant())
+    {
+        IConnectableLayer* weightsLayer = m_Network->AddConstantLayer(CreateConstTensor(weightName).first);
+
+        weightInfo.SetConstant();
+        weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
+        weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1u));
+    }
+
+    if(desc.m_BiasEnabled && m_TensorsInfo[biasName].isConstant())
+    {
+        IConnectableLayer* biasLayer = m_Network->AddConstantLayer(CreateConstTensor(biasName).first);
+
+        biasInfo.SetConstant();
+        biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo);
+        biasLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(2u));
+    }
+
+    if (outputInfo[0].GetNumDimensions() > 2)
+    {
+        // Flatten the inferred N-D output to 2D [batch_size, output_size].
+        std::vector<unsigned int> reshapedDimensions(2);
+        reshapedDimensions[1] = weightInfo.GetShape()[1];
+        reshapedDimensions[0] = outputInfo[0].GetNumElements() / reshapedDimensions[1];
+
+        if (outputInfo[0].GetNumElements() % reshapedDimensions[1] != 0)
+        {
+            throw ParseException(
+                    fmt::format("Failed to deduce output tensor shape from filter size {} {}",
+                                reshapedDimensions[1],
+                                CHECK_LOCATION().AsString()));
+        }
+
+        armnn::TensorInfo reshapedOutputTensorInfo = outputInfo[0];
+        reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+        layer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);
+
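+        // Expand the 2D FullyConnected output back to the inferred N-D output shape.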
+        ReshapeDescriptor reshapeDescriptor;
+        reshapeDescriptor.m_TargetShape = outputInfo[0].GetShape();
+
+        std::string reshapeLayerName = fmt::format("ExpandDims_for:{}", layer->GetName());
+        IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(reshapeDescriptor, reshapeLayerName.c_str());
+
+        layer->GetOutputSlot(0).Connect(reshapeLayer->GetInputSlot(0));
+        reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);
+
+        RegisterInputSlots(reshapeLayer, {layer->GetName()});
+        layer = reshapeLayer;
+    }
+
+    RegisterOutputSlots(layer, { outputName });
 }
 
 void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc)