IVGCVSW-2127 - Update HAL Policy for merger
* Remove permutation when concat axis is innermost
* Add additional parameter to IsMergerSupported as changed in armnn
!armnn:151
Change-Id: Ie214c9573f242d8f04d58fc61621ad3831991d9a
diff --git a/1.0/HalPolicy.cpp b/1.0/HalPolicy.cpp
index d0bd95b..719d1a2 100644
--- a/1.0/HalPolicy.cpp
+++ b/1.0/HalPolicy.cpp
@@ -241,17 +241,22 @@
}
else if (tensorDimensionsAdded == 2)
{
- outputShape = armnn::TensorShape({1, 1, outputShape[0], outputShape[1]});
+ outputShape = armnn::TensorShape({1, 1, outputShape[0]});
}
}
- // Get the pair of permutations required for the concatenation
+ // Check if permutation is required and get the pair of permutations required for the concatenation.
+ // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- CreatePermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+ bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
- outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ if (needPermute)
+ {
+ outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ }
+
outputInfo.SetShape(outputShape);
// this is no-op for identity swizzles, otherwise it replaces both
@@ -260,10 +265,11 @@
// Create an armnn merger layer descriptor - this will also perform validation on the input shapes
armnn::OriginsDescriptor mergerDescriptor;
+
try
{
- // The merger descriptor is always created across the only supported concat
- // dimension, which is 0 or 1
+ // The merger descriptor is always created across the only supported concat dimension
+ // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
mergerDescriptor =
armnn::CreateMergerDescriptorForConcatenation(
inputShapes.begin(), inputShapes.end(), concatDim);
@@ -274,7 +280,7 @@
}
// Validate the output shape is correct given the input shapes based on the
- // only valid concat dimension which is 0 or 1
+ // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
{
return Fail("%s: Error validating the output shape for concat", __func__);
@@ -287,6 +293,7 @@
armnn::IsMergerSupported,
data.m_Compute,
inputTensorInfos,
+ outputInfo,
mergerDescriptor))
{
return false;
@@ -305,11 +312,14 @@
inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
}
- // Add permutation layer and connect the output to it, the permutation becomes the output layer
- armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
- layer->GetOutputSlot(0),
- permutationPair.second);
- layer = &deswizzleLayer;
+ if (needPermute)
+ {
+ // Add permutation layer and connect the output to it, the permutation becomes the output layer
+ armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
+ layer->GetOutputSlot(0),
+ permutationPair.second);
+ layer = &deswizzleLayer;
+ }
if (inputsHaveBeenReshaped)
{
@@ -323,8 +333,7 @@
}
else if (tensorDimensionsAdded == 2)
{
- afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2],
- afterConcatInfo.GetShape()[3] }));
+ afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
}
layer = &AddReshapeLayer(
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 68ce09d..c86ad93 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -390,50 +390,29 @@
}
}
-void CreatePermutationParameters(const unsigned int numberOfDimensions,
- int32_t & concatDimension,
- std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
+bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
+ int32_t & concatDimension,
+ std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
{
+ bool needPermute = false;
BOOST_ASSERT(numberOfDimensions >= 3);
// ArmNN uses Compute Library subtensors to perform concatenation
- // This only works when concatenating along dimension 0 or 1 for a 4-D tensor,
- // or along dimension 0 for a 3-D tensor.
- if (numberOfDimensions == 4)
+ // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
+ // or along dimension 0 or 2 for a 3-D tensor.
+ if (numberOfDimensions == 4 && concatDimension == 2)
{
- if (concatDimension == 3)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(NHWCToArmNN, ArmNNToNHWC);
- }
- else if (concatDimension == 2)
- {
- concatDimension = 1;
- permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- }
-
+ concatDimension = 1;
+ permutationPair = std::make_pair(SwapDim1And2, SwapDim1And2);
+ needPermute = true;
}
- else if (numberOfDimensions == 3)
+ else if (numberOfDimensions == 3 && concatDimension == 1)
{
- if (concatDimension == 2)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorRight, RotateTensorLeft);
- }
- else if (concatDimension == 1)
- {
- concatDimension = 0;
- permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
- }
- else
- {
- permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
- }
+ concatDimension = 0;
+ permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
+ needPermute = true;
}
+ return needPermute;
}
} // anonymous namespace