IVGCVSW-2127 Update HAL policy for merger
* Remove the permutation when the concat axis is the innermost dimension
* Pass the output tensor info to IsMergerSupported, matching the updated signature in armnn
!armnn:151
Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index d0bd95b..719d1a2 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -241,17 +241,22 @@
}
else if (tensorDimensionsAdded == 2)
{
- outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]});
+ outputShape = armnn::TensorShape({1, 1, outputShape[0]});
}
}
- // Get the pair of permutations required for the concatenation
+ // Check whether a permutation is required and, if so, get the pair of permutations for the concatenation.
+ // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+ bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
- outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ if (needPermute)
+ {
+ outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ }
+
outputInfo.SetShape(outputShape);
// this is no-op for identity swizzles, otherwise it replaces both
@@ -260,10 +265,11 @@
// Create an armnn merger layer descriptor - this will also perform validation on the input shapes
armnn::OriginsDescriptor mergerDescriptor;
+
try
{
- // The merger descriptor is always created across the only supported concat
- // dimension, which is 0 or 1
+ // The merger descriptor is always created across the only supported concat dimension
+ // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
mergerDescriptor =
armnn::CreateMergerDescriptorForConcatenation(
inputShapes.begin(), inputShapes.end(), concatDim);
@@ -274,7 +280,7 @@
}
// Validate the output shape is correct given the input shapes based on the
- // only valid concat dimension which is 0 or 1
+ // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
{
return Fail("%s: Error validating the output shape for concat", __func__);
@@ -287,6 +293,7 @@
armnn::IsMergerSupported,
data.m_Compute,
inputTensorInfos,
+ outputInfo,
mergerDescriptor))
{
return false;
@@ -305,11 +312,14 @@
inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
}
- // Add permutation layer and connect the output to it, the permutation becomes the output layer
- armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
- layer->GetOutputSlot(0),
- permutationPair.second);
- layer = &deswizzleLayer;
+ if (needPermute)
+ {
+ // Add permutation layer and connect the output to it, the permutation becomes the output layer
+ armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
+ layer->GetOutputSlot(0),
+ permutationPair.second);
+ layer = &deswizzleLayer;
+ }
if (inputsHaveBeenReshaped)
{
@@ -323,8 +333,7 @@
}
else if (tensorDimensionsAdded == 2)
{
- afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2],
- afterConcatInfo.GetShape()[3] }));
+ afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
}
layer = &AddReshapeLayer(