IVGCVSW-2996 Add Reshape layer to ParseFullyConnected in TfLite parser
when the input is > 2D, to flatten the input to 2D [batch_size, input_size]

Change-Id: Id9d9ff996225c7d0938204ae0ceb330a11e264f5
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 44b3614..5733343 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1780,17 +1780,57 @@
     }
     BOOST_ASSERT(layer != nullptr);
 
+    armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
+
+    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
+
+    if (inputTensorInfo.GetNumDimensions() > 2)
+    {
+        // Add reshape to flatten to 2D [batch_size, input_size],
+        // where "input_size" corresponds to the number of inputs to the layer,
+        // matching the second dimension of weights,
+        // and "batch_size" is calculated by dividing the number of elements by "input_size".
+        std::vector<unsigned int> reshapedDimensions(2);
+        reshapedDimensions[1] = filterTensorInfo.GetShape()[1];
+        reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
+
+        if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
+        {
+            throw ParseException(
+                    boost::str(
+                            boost::format(
+                                    "Failed to deduce input tensor shape from filter size %1% %2%")
+                            % reshapedDimensions[1]
+                            % CHECK_LOCATION().AsString()));
+        }
+
+        armnn::TensorInfo reshapedTensorInfo = ToTensorInfo(inputs[0]);
+        reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
+
+        std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+        armnn::ReshapeDescriptor desc;
+        desc.m_TargetShape = reshapedTensorInfo.GetShape();
+        armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, reshapeLayerName.c_str());
+
+        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+        reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
+
+        RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {inputTensorIndexes[0]});
+    }
+    else
+    {
+        // register the input connection slot for the layer
+        // only the tensors for the inputs are relevant, exclude the const tensors
+        RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
+    }
+
     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
-    // register the input connection slot for the layer
-    // only the tensors for the inputs are relevant, exclude the const tensors
-    auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
-
     // we need to add the activation layer and fortunately we don't need to care about the data layout
     armnn::IConnectableLayer* fusedActivationLayer = AddFusedActivationLayer(layer, 0,
                                                                              options->fused_activation_function);
+
     // register the output connection slots for the layer, connections are made after all layers have been created
     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
     RegisterOutputSlots(subgraphIndex, operatorIndex, fusedActivationLayer, {outputTensorIndexes[0]});
diff --git a/src/armnnTfLiteParser/test/FullyConnected.cpp b/src/armnnTfLiteParser/test/FullyConnected.cpp
index 7ee64a4..54d7bcb 100644
--- a/src/armnnTfLiteParser/test/FullyConnected.cpp
+++ b/src/armnnTfLiteParser/test/FullyConnected.cpp
@@ -151,4 +151,24 @@
         { (400+10)/2 });
 }
 
+struct FullyConnectedWithBiasMultipleOutputsFixture : FullyConnectedFixture
+{
+    FullyConnectedWithBiasMultipleOutputsFixture()
+            : FullyConnectedFixture("[ 1, 4, 2, 1 ]",     // inputShape
+                                    "[ 2, 1 ]",           // outputShape
+                                    "[ 1, 4 ]",           // filterShape
+                                    "[ 2, 3, 4, 5 ]",     // filterData
+                                    "[ 1 ]",              // biasShape
+                                    "[ 10, 0, 0, 0 ]" )   // biasData
+    {}
+};
+
+BOOST_FIXTURE_TEST_CASE(FullyConnectedWithBiasMultipleOutputs, FullyConnectedWithBiasMultipleOutputsFixture)
+{
+    RunTest<2, armnn::DataType::QuantisedAsymm8>(
+            0,
+            { 1, 2, 3, 4, 10, 20, 30, 40 },
+            { (40+10)/2, (400+10)/2 });
+}
+
 BOOST_AUTO_TEST_SUITE_END()