IVGCVSW-4378 Fix transpose outputInfo for skipped Concat VTS in CL and Neon
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I7962be3a77cacf15dad594f0a907499c5b39bfeb
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 3b01b40..ebfc43b 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -579,20 +579,21 @@
}
}
-bool CheckReshapeSupported(ConversionData& data,
- std::vector<LayerInputHandle>& inputs,
- std::vector<armnn::TensorShape>& inputShapes,
- const armnn::PermutationVector& mapping,
- const armnn::TensorInfo& outputInfo)
+bool TransposeInputTensors(ConversionData& data,
+ std::vector<LayerInputHandle>& inputs,
+ std::vector<armnn::TensorShape>& inputShapes,
+ const armnn::PermutationVector& mapping)
{
if (!mapping.IsEqual(IdentityPermutation4D))
{
+ armnn::TensorInfo outputTransposeInfo;
size_t nInputs = inputs.size();
for (size_t i=0; i<nInputs; ++i)
{
// check permute layer
armnn::TransposeDescriptor transposeDesc;
transposeDesc.m_DimMappings = mapping;
+ outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping);
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
@@ -600,7 +601,7 @@
data.m_Backends,
isSupported,
inputs[i].GetTensorInfo(),
- outputInfo,
+ outputTransposeInfo,
transposeDesc);
if (!isSupported)
{
@@ -1985,7 +1986,7 @@
// this is no-op for identity swizzles, otherwise it replaces both
// the handles and shapes with the swizzled layer output handles and shapes
- if (!CheckReshapeSupported(data, inputHandles, inputShapes, permutationPair.first, outputInfo))
+ if (!TransposeInputTensors(data, inputHandles, inputShapes, permutationPair.first))
{
return false;
}
@@ -2046,14 +2047,17 @@
{
armnn::TransposeDescriptor transposeDesc;
transposeDesc.m_DimMappings = permutationPair.second;
+ armnn::TensorInfo inputTransposeInfo = layer->GetOutputSlot(0).GetTensorInfo();
+ armnn::TensorInfo outputTransposeInfo = armnnUtils::TransposeTensorShape(inputTransposeInfo,
+ permutationPair.second);
bool isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsTransposeSupported,
data.m_Backends,
isSupported,
- layer->GetOutputSlot(0).GetTensorInfo(),
- outputInfo,
+ inputTransposeInfo,
+ outputTransposeInfo,
transposeDesc);
if (!isSupported)
{