IVGCVSW-2994 Add Reshape layer to ParseUnpack in TfLite parser
to remove the unpacked dimension from each Splitter output,
and correct the ReshapeFixtureWithReshapeDimsFlatten test output shape
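
The shape handling introduced here can be illustrated with a small standalone
sketch (illustrative only, not part of the patch; the shapes and axis are made
up): unpacking a [4, 2, 6] tensor along axis 1 gives two Splitter views of
shape [4, 1, 6], and the new per-output Reshape drops the unpacked axis to
give [4, 6].

    // Illustrative sketch mirroring the reshape-dimension loop added to
    // ParseUnpack. Example values only.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        const std::vector<unsigned int> splitOutShape = { 4, 1, 6 }; // one Splitter view
        const unsigned int unpackAxis = 1;

        std::vector<unsigned int> reshapedDims;
        for (std::size_t i = 0; i < splitOutShape.size(); ++i)
        {
            if (i != unpackAxis)
            {
                reshapedDims.push_back(splitOutShape[i]); // keep every dim except the unpacked axis
            }
        }

        for (unsigned int d : reshapedDims)
        {
            std::cout << d << " "; // prints: 4 6
        }
        std::cout << "\n";
        return 0;
    }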

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I517d315475612ac8b773930f9b58cac316fa8553
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index b7258b3..44b3614 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1888,6 +1888,19 @@
     CHECK_VALID_SIZE(inputs.size(), 1);
 
     armnn::TensorInfo inputTensorInfo  = ToTensorInfo(inputs[0]);
+
+    if (unpackAxis >= inputTensorInfo.GetNumDimensions())
+    {
+        throw ParseException(
+                boost::str(
+                        boost::format(
+                                "The unpack axis %1% cannot be greater than or equal to "
+                                "the number of input dimensions %2% %3%")
+                        % unpackAxis
+                        % inputTensorInfo.GetNumDimensions()
+                        % CHECK_LOCATION().AsString()));
+    }
+
     unsigned int unpackNum = CHECKED_NON_NEGATIVE(options->num);
     // If num is not defined, automatically infer from the length of the dimension axis.
     if(unpackNum == 0)
@@ -1935,20 +1948,46 @@
     auto layerName = boost::str(boost::format("Unpack:%1%:%2%") % subgraphIndex % operatorIndex);
     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
 
+    TensorShape splitOutShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
+                                            unpackDimSizes.data());
+
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
 
-    TensorShape outShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
-        unpackDimSizes.data());
+    // Compute the per-output shape with the unpacked dimension removed.
+    unsigned int reshapedNumDimensions = inputDimSize - 1;
+    std::vector<unsigned int> reshapedDimensions(reshapedNumDimensions);
 
-    for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
+    unsigned int reshapeIndex = 0;
+    for (unsigned int i = 0; i < inputDimSize; ++i)
     {
-        layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(outShape,
-            inputTensorInfo.GetDataType()));
+        if (i == unpackAxis)
+        {
+            continue;
+        }
+        reshapedDimensions[reshapeIndex++] = unpackDimSizes[i];
     }
 
-    auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
+    // Add a Reshape layer after each Splitter output to remove the unpacked dimension.
+    for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
+    {
+        armnn::TensorInfo reshapedTensorInfo = inputTensorInfo;
+        reshapedTensorInfo.SetShape(armnn::TensorShape{ reshapedNumDimensions, reshapedDimensions.data() });
+
+        std::string reshapeLayerName = boost::str(boost::format("Reshape_for:%1%") % layer->GetName());
+        armnn::ReshapeDescriptor desc;
+        desc.m_TargetShape = reshapedTensorInfo.GetShape();
+        armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, reshapeLayerName.c_str());
+
+        layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(splitOutShape, inputTensorInfo.GetDataType()));
+        layer->GetOutputSlot(k).Connect(reshapeLayer->GetInputSlot(0));
+
+        reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
+
+        uint32_t reshapedOutputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[k]);
+        armnn::IOutputSlot* slot = &(reshapeLayer->GetOutputSlot(0));
+        RegisterProducerOfTensor(subgraphIndex, reshapedOutputId, slot);
+    }
 }
 
 void TfLiteParser::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
diff --git a/src/armnnTfLiteParser/test/Reshape.cpp b/src/armnnTfLiteParser/test/Reshape.cpp
index ef4b761..62fbad6 100644
--- a/src/armnnTfLiteParser/test/Reshape.cpp
+++ b/src/armnnTfLiteParser/test/Reshape.cpp
@@ -95,17 +95,17 @@
 
 struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
 {
-    ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 1, 9 ]", "[ -1 ]") {}
+    ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {}
 };
 
 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten)
 {
     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
-    RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
+    RunTest<1, armnn::DataType::QuantisedAsymm8>(0,
                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
-                == armnn::TensorShape({1,9})));
+                == armnn::TensorShape({9})));
 }
 
 struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture