IVGCVSW-3843 Add support for per-axis quantization to BuildArmComputeTensorInfo
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I0bb0e9da306eee3e19dc9967a6c8bb01da998deb
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index c7d250a..b2955b9 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -106,9 +106,11 @@
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo)
{
const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
- const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
- tensorInfo.GetQuantizationOffset());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+
+ const arm_compute::QuantizationInfo aclQuantizationInfo = tensorInfo.HasMultipleQuantizationScales() ?
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScales()) :
+ arm_compute::QuantizationInfo(tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset());
return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
}
@@ -116,15 +118,10 @@
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
armnn::DataLayout dataLayout)
{
- const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
- const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
- const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
- tensorInfo.GetQuantizationOffset());
+ arm_compute::TensorInfo aclTensorInfo = BuildArmComputeTensorInfo(tensorInfo);
+ aclTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
- arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
- clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
-
- return clTensorInfo;
+ return aclTensorInfo;
}
arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)