IVGCVSW-5270 Update ConvertConcatenation function to use ShapeInferenceMethod

Signed-off-by: Keith Davis <keith.davis@arm.com>
Change-Id: I13e16d271ba55217b98a439aa82931f809fdeeb8
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index fe8e026..fa67f79 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1968,7 +1968,7 @@
          typename HalModel     = typename HalPolicy::Model>
 bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, ConversionData& data)
 {
-    using HalOperand     = typename HalPolicy::Operand;
+    using HalOperand = typename HalPolicy::Operand;
     using HalOperandType = typename HalPolicy::OperandType;
 
     // The first N (0..N-1) inputs are tensors. The Nth input is the concatenation axis.
@@ -1992,9 +1992,9 @@
         return Fail("%s: Operation has no outputs", __func__);
     }
 
-    armnn::TensorInfo  outputInfo  = GetTensorInfoForOperand(*outputOperand);
-    armnn::TensorShape outputShape = outputInfo.GetShape();
-
+    armnn::TensorInfo  outputInfo      = GetTensorInfoForOperand(*outputOperand);
+    armnn::TensorShape outputShape     = outputInfo.GetShape();
+    const bool         isDynamicTensor = IsDynamicTensor(outputInfo);
     //
     // handle negative concat dims along the lines of tensorflow as described here:
     //    https://www.tensorflow.org/api_docs/python/tf/concat
@@ -2016,9 +2016,8 @@
     inputHandles.reserve(numInputTensors);
     inputShapes.reserve(numInputTensors);
 
-    bool inputsHaveBeenReshaped        = false;
-    unsigned int tensorDimensionsAdded = 0;
-
+    bool          inputsHaveBeenReshaped = false;
+    unsigned int  tensorDimensionsAdded  = 0;
     for (uint32_t i = 0; i < numInputTensors; ++i)
     {
         const HalOperand* operand = GetInputOperand<HalPolicy>(operation, i, model);
@@ -2033,7 +2032,7 @@
             return Fail("%s: Operation has invalid inputs", __func__);
         }
 
-        armnn::TensorShape operandShape     = GetTensorShapeForOperand(*operand);
+        armnn::TensorShape operandShape = GetTensorShapeForOperand(*operand);
         if (operandShape.GetNumDimensions() == 0)
         {
             return Fail("%s: Operands with rank 0 are not supported", __func__);
@@ -2068,19 +2067,15 @@
                                        operandInputHandle.GetTensorInfo(),
                                        reshapeInfo,
                                        reshapeDescriptor);
+
             if (!isSupported)
             {
                 return false;
             }
-
-            armnn::IConnectableLayer& newReshape = AddReshapeLayer(
-                    *data.m_Network,
-                    operandInputHandle,
-                    reshapeInfo
-            );
+            armnn::IConnectableLayer& newReshape = AddReshapeLayer(*data.m_Network, operandInputHandle, reshapeInfo);
 
             // Point to the reshape operation rather then the input operation
-            operandShape = reshapeInfo.GetShape();
+            operandShape       = reshapeInfo.GetShape();
             operandInputHandle = LayerInputHandle(true, &newReshape.GetOutputSlot(0), reshapeInfo);
         }
 
@@ -2103,29 +2098,47 @@
         // Add extra dimensions to the output shape to reflect the addition of the reshape layers
         if (tensorDimensionsAdded == 1)
         {
-            outputShape = armnn::TensorShape({1, outputShape[0], outputShape[1]});
+            if (IsDynamicTensor(outputInfo))
+            {
+                outputShape = armnn::TensorShape({1, 0, 0}, {true, false, false});
+            }
+            else
+            {
+                outputShape = armnn::TensorShape({1, outputShape[0], outputShape[1]});
+            }
         }
         else if (tensorDimensionsAdded == 2)
         {
-            outputShape = armnn::TensorShape({1, 1, outputShape[0]});
+            if (IsDynamicTensor(outputInfo))
+            {
+                outputShape = armnn::TensorShape({1, 1, 0}, {true, true, false});
+            }
+            else
+            {
+                outputShape = armnn::TensorShape({1, 1, outputShape[0]});
+            }
         }
     }
 
     // Check if permutations is required and get the pair of permutations required for the concatenation.
     // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
     std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
-            std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
+        std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
 
-    bool needPermute =
-            CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+    bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(),
+                                                         concatDim,
+                                                         permutationPair);
 
-    if (needPermute)
+    // Only relevant to static tensors as dynamic output tensors will be transposed as a result of inferring from input
+    if (!isDynamicTensor)
     {
-        outputShape = armnnUtils::TransposeTensorShape(outputShape, permutationPair.first);
+        if (needPermute)
+        {
+            outputShape = armnnUtils::TransposeTensorShape(outputShape, permutationPair.first);
+        }
+
+        outputInfo.SetShape(outputShape);
     }
-
-    outputInfo.SetShape(outputShape);
-
     // this is no-op for identity swizzles, otherwise it replaces both
     // the handles and shapes with the swizzled layer output handles and shapes
     if (!TransposeInputTensors(data, inputHandles, inputShapes, permutationPair.first))
@@ -2140,33 +2153,43 @@
     {
         // The concat descriptor is always created across the only supported concat dimension
         // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
-        concatDescriptor =
-                armnn::CreateDescriptorForConcatenation(inputShapes.begin(), inputShapes.end(), concatDim);
-    }
-    catch (std::exception& error)
+        concatDescriptor = armnn::CreateDescriptorForConcatenation(inputShapes.begin(),
+                                                                   inputShapes.end(),
+                                                                   concatDim);
+    } catch (std::exception& error)
     {
         return Fail("%s: Error preparing concat descriptor. %s", __func__, error.what());
     }
 
     // Validate the output shape is correct given the input shapes based on the
     // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
-    if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
+    if (!isDynamicTensor)
     {
-        return Fail("%s: Error validating the output shape for concat", __func__);
+        if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
+        {
+            return Fail("%s: Error validating the output shape for concat", __func__);
+        }
     }
 
     std::vector<const armnn::TensorInfo*> inputTensorInfos;
     std::transform(inputHandles.begin(), inputHandles.end(), std::back_inserter(inputTensorInfos),
-                   [](const LayerInputHandle& h) -> const armnn::TensorInfo*{ return &h.GetTensorInfo(); });
+                   [](const LayerInputHandle& h)->const armnn::TensorInfo*{ return &h.GetTensorInfo(); });
 
-    bool isSupported = false;
-    FORWARD_LAYER_SUPPORT_FUNC(__func__,
-                               IsConcatSupported,
-                               data.m_Backends,
-                               isSupported,
-                               inputTensorInfos,
-                               outputInfo,
-                               concatDescriptor);
+    bool isSupported  = false;
+    auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported){
+        FORWARD_LAYER_SUPPORT_FUNC(__func__, IsConcatSupported, data.m_Backends, isSupported, inputTensorInfos,
+                                   outputInfo, concatDescriptor);
+    };
+
+    if (!isDynamicTensor)
+    {
+        validateFunc(outputInfo, isSupported);
+    }
+    else
+    {
+        isSupported = AreDynamicTensorsSupported();
+    }
+
     if (!isSupported)
     {
         return false;
@@ -2175,7 +2198,6 @@
     armnn::IConnectableLayer* layer = data.m_Network->AddConcatLayer(concatDescriptor);
     assert(layer != nullptr);
     layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
-
     // Connect inputs to the layer
     const int numInputSlots = layer->GetNumInputSlots();
     assert(static_cast<std::size_t>(numInputSlots) == inputHandles.size());
@@ -2185,15 +2207,14 @@
         inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
     }
 
-    if (needPermute)
-    {
+    // Adds a transpose layer after the concat to deswizzle its output; yields false if the transpose is unsupported
+    auto transposeOutputShape = [&](){
         armnn::TransposeDescriptor transposeDesc;
         transposeDesc.m_DimMappings = permutationPair.second;
         armnn::TensorInfo inputTransposeInfo  = layer->GetOutputSlot(0).GetTensorInfo();
         armnn::TensorInfo outputTransposeInfo = armnnUtils::TransposeTensorShape(inputTransposeInfo,
                                                                                  permutationPair.second);
-
-        bool isSupported = false;
+        isSupported = false;
         FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                    IsTransposeSupported,
                                    data.m_Backends,
@@ -2201,56 +2222,92 @@
                                    inputTransposeInfo,
                                    outputTransposeInfo,
                                    transposeDesc);
+
         if (!isSupported)
         {
             return false;
         }
         // Add permutation layer and connect the output to it, the permutation becomes the output layer
-        armnn::IConnectableLayer& deswizzleLayer = AddTransposeLayer(*data.m_Network,
-                                                                     layer->GetOutputSlot(0),
+        armnn::IConnectableLayer& deswizzleLayer = AddTransposeLayer(*data.m_Network, layer->GetOutputSlot(0),
                                                                      permutationPair.second);
         layer = &deswizzleLayer;
+
+        return true;
+    };
+
+    if (needPermute && !isDynamicTensor)
+    {
+        transposeOutputShape();
     }
 
     if (inputsHaveBeenReshaped)
     {
+        if (isDynamicTensor)
+        {
+            // Type 1 dynamic output: IsTensorInfoSet() is called for its side effect of inferring the concat shape
+            layer->GetOutputSlot(0).IsTensorInfoSet();
+            if (!ValidateConcatOutputShape(inputShapes,
+                                           layer->GetOutputSlot(0).GetTensorInfo().GetShape(),
+                                           concatDim))
+            {
+                return Fail("%s: Error validating the output shape for concat", __func__);
+            }
+            transposeOutputShape();
+        }
+
         armnn::TensorInfo afterConcatInfo = layer->GetOutputSlot(0).GetTensorInfo();
 
         // Undo the reshape knowing the amount of dimensions added
         if (tensorDimensionsAdded == 1)
         {
-            afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[1],
-                                                          afterConcatInfo.GetShape()[2] }));
+            afterConcatInfo.SetShape(
+                armnn::TensorShape({afterConcatInfo.GetShape()[1], afterConcatInfo.GetShape()[2]}));
         }
         else if (tensorDimensionsAdded == 2)
         {
-            afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
+            afterConcatInfo.SetShape(armnn::TensorShape({afterConcatInfo.GetShape()[2]}));
         }
 
         armnn::ReshapeDescriptor reshapeDescriptor;
         reshapeDescriptor.m_TargetShape = afterConcatInfo.GetShape();
+        armnn::TensorInfo concatInfo = layer->GetOutputSlot(0).GetTensorInfo();
 
-        bool isSupported = false;
-        FORWARD_LAYER_SUPPORT_FUNC(__func__,
-                                   IsReshapeSupported,
-                                   data.m_Backends,
-                                   isSupported,
-                                   layer->GetOutputSlot(0).GetTensorInfo(),
-                                   afterConcatInfo,
-                                   reshapeDescriptor);
+        isSupported = false;
+        auto validateReshapeFunc = [&](const armnn::TensorInfo& afterConcatInfo, bool& isSupported){
+            FORWARD_LAYER_SUPPORT_FUNC(__func__,
+                                       IsReshapeSupported,
+                                       data.m_Backends,
+                                       isSupported,
+                                       concatInfo,
+                                       afterConcatInfo,
+                                       reshapeDescriptor);
+        };
+
+        if (!IsDynamicTensor(afterConcatInfo))
+        {
+            validateReshapeFunc(afterConcatInfo, isSupported);
+        }
+        else
+        {
+            isSupported = AreDynamicTensorsSupported();
+        }
+
         if (!isSupported)
         {
             return false;
         }
 
-        layer = &AddReshapeLayer(
-                *data.m_Network,
-                layer->GetOutputSlot(0),
-                afterConcatInfo
-        );
+        layer = &AddReshapeLayer(*data.m_Network, layer->GetOutputSlot(0), afterConcatInfo);
+        return SetupAndTrackLayerOutputSlot<HalPolicy>(operation,
+                                                       0,
+                                                       *layer,
+                                                       model,
+                                                       data,
+                                                       nullptr,
+                                                       validateReshapeFunc);
     }
 
-    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
+    return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
 }
 
 template<typename HalPolicy,