IVGCVSW-2563 Fix bug in TfLiteParser::ParseConcatenation

The parser no longer restricts concatenation to the NHWC/NCHW channel
axis: the concat axis is set directly on the OriginsDescriptor and the
SwizzleIn/DeswizzleOut permute work-around is removed.
ProcessConcatInputTensorInfo now takes the rank from the descriptor and
computes view origins for the requested axis. 3D parser tests cover
concatenation along axes 0, 1 and 2.

Change-Id: I8fbf27b383a821e062f72809cc2e269fcd18851c
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 359695b..8b2a818 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -33,48 +33,6 @@
 {
 namespace
 {
-const PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
-const PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
-
-IConnectableLayer* SwizzleIn(INetwork& network,
-                             IConnectableLayer* layer,
-                             unsigned int inputSlotIndex,
-                             const TensorInfo & inputInfo)
-{
-    BOOST_ASSERT(layer != nullptr);
-    // Add swizzle layer
-    std::stringstream name;
-    name << "swizzle_for-" << layer->GetName() << ":in" << inputSlotIndex;
-    IConnectableLayer* const swizzleLayer = network.AddPermuteLayer(NHWCToArmNN, name.str().c_str());
-    // Set swizzled output shape
-    const TensorInfo swizzleOutInfo = armnnUtils::Permuted(inputInfo, NHWCToArmNN);
-    swizzleLayer->GetOutputSlot(0).SetTensorInfo(swizzleOutInfo);
-    // Connect the swizzle layer to the actual layer
-    swizzleLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(inputSlotIndex));
-
-    return swizzleLayer;
-}
-
-IConnectableLayer* DeswizzleOut(INetwork& network,
-                                IConnectableLayer* layer,
-                                unsigned int outputSlotIndex,
-                                const TensorInfo & outputInfo)
-{
-    BOOST_ASSERT(layer != nullptr);
-    // Add deswizzle layer
-    std::stringstream name;
-    name << "deswizzle_for-" << layer->GetName() << ":out" << outputSlotIndex;
-    IConnectableLayer* const deswizzleLayer = network.AddPermuteLayer(ArmNNToNHWC, name.str().c_str());
-    // Set deswizzled output shape
-    deswizzleLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
-    // Set original layer output shape
-    const TensorInfo deswizzleOutInfo = armnnUtils::Permuted(outputInfo, NHWCToArmNN);
-    layer->GetOutputSlot(outputSlotIndex).SetTensorInfo(deswizzleOutInfo);
-    // Connect the actual layer to the deswizzle layer
-    layer->GetOutputSlot(outputSlotIndex).Connect(deswizzleLayer->GetInputSlot(0));
-
-    return deswizzleLayer;
-}
 
 const uint32_t VIRTUAL_OPERATOR_ID = std::numeric_limits<uint32_t>::max();
 
@@ -1383,39 +1341,25 @@
     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
     CHECK_VALID_SIZE(outputs.size(), 1);
 
-    unsigned int numInputs = static_cast<unsigned int>(inputs.size());
-    unsigned int numConcatView = numInputs;
+    unsigned int numConcatView = static_cast<unsigned int>(inputs.size());
+    uint32_t inputRank = ToTensorInfo(inputs[0]).GetNumDimensions();
 
-    OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), MaxNumOfTensorDimensions);
-    std::vector<unsigned int>mergeDimSizes(MaxNumOfTensorDimensions, 0u);
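+    // A negative axis is interpreted as counting from the end of the rank, i.e. axis + rank(values).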
+    const unsigned int concatDimInput = static_cast<unsigned int>(
+        (static_cast<int>(inputRank) + options->axis) % static_cast<int>(inputRank));
 
-    unsigned int mergeDim = 0;
+    OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), inputRank);
+    concatDescriptor.SetConcatAxis(concatDimInput);
 
-    // This concatDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
-    // axis could also be negative numbers. Negative axis are interpreted as counting from the end of the rank,
-    // i.e., axis + rank(values)-th dimension.
-    int32_t inputRank = static_cast<int32_t>(ToTensorInfo(inputs[0]).GetNumDimensions());
-    const unsigned int concatDimInput = static_cast<unsigned int>((inputRank + options->axis) % inputRank);
-
-    // ArmNN supports concatenation along the channel dimension for data formats NHWC and NCHW.
-    if (concatDimInput == 0 || concatDimInput == 2)
-    {
-        throw ParseException(
-            boost::str(
-                boost::format(
-                    "Dimension %1% for concatenation is not supported by Armnn. "
-                    "Node %2%")
-                % concatDimInput
-                % CHECK_LOCATION().AsString()));
-    }
+    unsigned int mergeDimOrigin = 0;
 
     for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
     {
         TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
 
-        // process the input tensor info
-        armnnUtils::ProcessConcatInputTensorInfo(inputTensorInfo, concatDescriptor,
-                                                 concatDimInput, viewIndex, mergeDimSizes, mergeDim);
+        // Set up the concatDescriptor view origin for this input
+        armnnUtils::ProcessConcatInputTensorInfo(
+            inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
     }
 
     auto layerName = boost::str(boost::format("Concatenation:%1%:%2%") % subgraphIndex % operatorIndex);
@@ -1425,39 +1369,14 @@
 
     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
-    if (concatDimInput == 3)
-    {
-        // Adding Fused Activation Layer after this moment....
-        for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
-        {
-            // add permute layers to swizzle the inputs
-            armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputs[viewIndex]);
-            IConnectableLayer* const swizzleLayer = SwizzleIn(*m_Network, layer, viewIndex, inputTensorInfo);
 
-            BOOST_ASSERT(swizzleLayer != nullptr);
+    layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
 
-            // register the input connection slots for the layer
-            // only the tensors for the inputs are relevant, exclude the const tensors
-            RegisterInputSlots(subgraphIndex, operatorIndex, swizzleLayer, {inputTensorIndexes[viewIndex]});
-        }
+    RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
 
-        // add permute layer to deswizzle the output
-        IConnectableLayer* const deswizzleLayer = DeswizzleOut(*m_Network, layer, 0, outputTensorInfo);
+    // add fused activation layer
+    layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
 
-        // add fused activation layer after the trailing swizzle layer
-        layer = AddFusedActivationLayer(deswizzleLayer, 0, options->fused_activation_function);
-    }
-    else
-    {
-        // set the layer output tensor info
-        layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
-
-        // register the input connection slots for the layer, connections are made after all layers have been created
-        // only the tensors for the inputs are relevant, exclude the const tensors
-        RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
-    }
-
-    // register the output connection slots for the layer, connections are made after all layers have been created
     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
 }
diff --git a/src/armnnTfLiteParser/test/Concatenation.cpp b/src/armnnTfLiteParser/test/Concatenation.cpp
index bb5aebf..d3d571f 100644
--- a/src/armnnTfLiteParser/test/Concatenation.cpp
+++ b/src/armnnTfLiteParser/test/Concatenation.cpp
@@ -189,4 +189,75 @@
                                70, 71, 72, 73 } } });
 }
 
+struct ConcatenationFixture3DDim0 : ConcatenationFixture
+{
+    ConcatenationFixture3DDim0() : ConcatenationFixture("[ 1, 2, 3]", "[ 2, 2, 3]", "[ 3, 2, 3]", "0" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim0, ConcatenationFixture3DDim0)
+{
+    RunTest<3, armnn::DataType::QuantisedAsymm8>(
+        0,
+        { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
+          { "inputTensor2", { 6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } },
+        { { "outputTensor", { 0,  1,  2,  3,  4,  5,
+                              6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } });
+}
+
+struct ConcatenationFixture3DDim1 : ConcatenationFixture
+{
+    ConcatenationFixture3DDim1() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 4, 3]", "[ 1, 6, 3]", "1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim1, ConcatenationFixture3DDim1)
+{
+    RunTest<3, armnn::DataType::QuantisedAsymm8>(
+        0,
+        { { "inputTensor1", { 0,  1,  2,  3,  4,  5 } },
+          { "inputTensor2", { 6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } },
+        { { "outputTensor", { 0,  1,  2,  3,  4,  5,
+                              6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } });
+}
+
+struct ConcatenationFixture3DDim2 : ConcatenationFixture
+{
+    ConcatenationFixture3DDim2() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 2, 6]", "[ 1, 2, 9]", "2" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDim2, ConcatenationFixture3DDim2)
+{
+    RunTest<3, armnn::DataType::QuantisedAsymm8>(
+        0,
+        { { "inputTensor1", { 0,  1,  2,
+                              3,  4,  5 } },
+          { "inputTensor2", { 6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } },
+        { { "outputTensor", { 0,  1,  2,  6,  7,  8,  9, 10, 11,
+                              3,  4,  5, 12, 13, 14, 15, 16, 17 } } });
+}
+
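+// Negative axis: -1 should wrap to the last dimension and give the same result
+// as the Dim2 case above (illustrative extra case; assumes the fixture accepts
+// a negative axis value in the generated model).
+struct ConcatenationFixture3DDimNegative1 : ConcatenationFixture
+{
+    ConcatenationFixture3DDimNegative1() : ConcatenationFixture("[ 1, 2, 3]", "[ 1, 2, 6]", "[ 1, 2, 9]", "-1" ) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseConcatenation3DDimNegative1, ConcatenationFixture3DDimNegative1)
+{
+    RunTest<3, armnn::DataType::QuantisedAsymm8>(
+        0,
+        { { "inputTensor1", { 0,  1,  2,
+                              3,  4,  5 } },
+          { "inputTensor2", { 6,  7,  8,  9, 10, 11,
+                             12, 13, 14, 15, 16, 17 } } },
+        { { "outputTensor", { 0,  1,  2,  6,  7,  8,  9, 10, 11,
+                              3,  4,  5, 12, 13, 14, 15, 16, 17 } } });
+}
+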
 BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index 9d633cf..2286f8b 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -16,12 +16,16 @@
 const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
 const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
 
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
-                                  const unsigned int& concatAxis, unsigned int inputIndex,
-                                  std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim)
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+                                  armnn::OriginsDescriptor& concatDescriptor,
+                                  const unsigned int& concatAxis,
+                                  unsigned int inputIndex,
+                                  unsigned int& mergeDimOrigin)
 {
+    const uint32_t inputRank = concatDescriptor.GetNumDimensions();
+
     // double check dimensions of the tensors
-    if (inputTensorInfo.GetNumDimensions() != armnn::MaxNumOfTensorDimensions)
+    if (inputTensorInfo.GetNumDimensions() != inputRank)
     {
         throw armnn::ParseException(
             boost::str(
@@ -29,33 +33,21 @@
                     "The number of dimensions: %1% for input tensors of the "
                     "concatenation op should be %2% %3%")
                 % inputTensorInfo.GetNumDimensions()
-                % armnn::MaxNumOfTensorDimensions
+                % inputRank
                 % CHECK_LOCATION().AsString()));
     }
 
-    // if concatenation axis is 3 then need to be permuted
-    if (concatAxis == 3)
-    {
-        inputTensorInfo = armnnUtils::Permuted(inputTensorInfo, NHWCToArmNN);
-    }
-
-    for (unsigned int dim = 0; dim < armnn::MaxNumOfTensorDimensions; ++dim)
-    {
-        mergeDimSizes[dim] = inputTensorInfo.GetShape()[dim];
-    }
-
-    // Concatenation dimension 1 is the only dimension supported in ArmNN
-    const unsigned int concatenationDim = 1;
-
-    for (unsigned int j = 0; j < concatenationDim; ++j)
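+    // The view origin is 0 on every dimension except the concat axis, where
+    // each input is placed at the running offset mergeDimOrigin.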
+    for (unsigned int j = 0; j < concatAxis; ++j)
     {
         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
     }
 
-    concatDescriptor.SetViewOriginCoord(inputIndex, concatenationDim, mergeDim);
-    mergeDim += mergeDimSizes[concatenationDim];
+    concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
+    mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
 
-    for (unsigned int j = concatenationDim + 1; j < armnn::MaxNumOfTensorDimensions; ++j)
+    for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
     {
         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
     }
diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp
index 24369dc..bcc1e5b 100644
--- a/src/armnnUtils/ParserHelper.hpp
+++ b/src/armnnUtils/ParserHelper.hpp
@@ -10,9 +10,11 @@
 namespace armnnUtils
 {
 
-void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::OriginsDescriptor& concatDescriptor,
-                                  const unsigned int& concatAxis, unsigned int inputIndex,
-                                  std::vector<unsigned int>& mergeDimSizes, unsigned int& mergeDim);
+void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
+                                  armnn::OriginsDescriptor& concatDescriptor,
+                                  const unsigned int& concatAxis,
+                                  unsigned int inputIndex,
+                                  unsigned int& mergeDimOrigin);
 
 /// Creates a tensor info after reducing the dimensions mentioned in axisData.
 void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo,