IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
diff --git a/src/backends/backendsCommon/MakeWorkloadHelper.hpp b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
index 7ef140e..8abc8a6 100644
--- a/src/backends/backendsCommon/MakeWorkloadHelper.hpp
+++ b/src/backends/backendsCommon/MakeWorkloadHelper.hpp
@@ -52,6 +52,7 @@
switch (dataType)
{
+
case DataType::Float16:
return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Float32:
@@ -65,6 +66,7 @@
return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
case DataType::Boolean:
return MakeWorkloadForType<BooleanWorkload>::Func(descriptor, info, std::forward<Args>(args)...);
+ case DataType::BFloat16:
case DataType::QSymmS16:
return nullptr;
default:
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index bb0c21f..b501b3d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -26,6 +26,8 @@
{
switch (inputDataType)
{
+ case DataType::BFloat16:
+ return DataType::BFloat16;
case DataType::Float16:
return DataType::Float16;
case DataType::Float32:
@@ -599,6 +601,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -628,6 +631,7 @@
std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -685,6 +689,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -706,6 +711,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -842,6 +848,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -929,6 +936,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Boolean,
@@ -992,6 +1000,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1016,6 +1025,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1042,6 +1052,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1077,6 +1088,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1111,6 +1123,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1183,6 +1196,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmS8,
DataType::QAsymmU8,
@@ -1258,6 +1272,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QAsymmS8,
@@ -1313,6 +1328,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmS8,
@@ -1339,6 +1355,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -1386,6 +1403,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmS8,
@@ -1460,6 +1478,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -1488,6 +1507,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1512,6 +1532,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
};
@@ -1538,6 +1559,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1565,6 +1587,7 @@
// Check the supported data types
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::Signed32,
@@ -1632,10 +1655,11 @@
std::vector<DataType> supportedTypes =
{
- DataType::Float16,
- DataType::Float32,
- DataType::QAsymmU8,
- DataType::QSymmS16
+ DataType::BFloat16,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
@@ -1657,6 +1681,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -1705,6 +1730,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS16
@@ -1736,6 +1762,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QSymmS16
@@ -2051,7 +2078,8 @@
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2082,7 +2110,8 @@
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16,
- DataType::Float16
+ DataType::Float16,
+ DataType::BFloat16
};
ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
@@ -2110,6 +2139,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2142,6 +2172,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2206,6 +2237,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QSymmS8,
@@ -2234,6 +2266,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2256,6 +2289,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2312,6 +2346,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::Signed32,
@@ -2401,6 +2436,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2429,6 +2465,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2475,6 +2512,7 @@
const std::vector<DataType> supportedInputTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2526,6 +2564,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16
};
@@ -2566,6 +2605,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::QAsymmU8,
DataType::QSymmS16
@@ -2608,6 +2648,7 @@
std::vector<DataType> supportedTypes
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2670,6 +2711,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -2894,6 +2936,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
@@ -2974,6 +3017,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float32,
DataType::Float16,
DataType::QAsymmU8,
@@ -3048,6 +3092,7 @@
std::vector<DataType> supportedTypes =
{
+ DataType::BFloat16,
DataType::Float16,
DataType::Float32,
DataType::QAsymmU8,
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 6ac76ec..2e1ce0a 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -325,6 +325,7 @@
TensorInfo biasInfo;
const TensorInfo * biasInfoPtr = nullptr;
+ static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
@@ -341,6 +342,11 @@
// If biases are not enabled pass a dummy tensorinfo for the validation
switch(input.GetDataType())
{
+ case DataType::BFloat16:
+ {
+ biasInfoPtr = &dummyBFloat16Bias;
+ break;
+ }
case DataType::Float16:
{
biasInfoPtr = &dummyFloat16Bias;
diff --git a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
index 5168333..df001b7 100644
--- a/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
+++ b/src/backends/backendsCommon/test/WorkloadTestUtils.hpp
@@ -95,6 +95,7 @@
switch(weightsType.value())
{
+ case armnn::DataType::BFloat16:
case armnn::DataType::Float16:
case armnn::DataType::Float32:
return weightsType;
diff --git a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
index f6f4b09..1e40b42 100644
--- a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.cpp
@@ -2342,6 +2342,13 @@
return Concat3dDim1TestImpl<DataType::Float16>(workloadFactory, memoryManager, 0.0f, 0);
}
+LayerTestResult<BFloat16, 3> ConcatBFloat16Test(
+    IWorkloadFactory& workloadFactory,
+    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{ // BFloat16 variant: delegates to Concat3dDim1TestImpl; 0.0f, 0 presumably qScale/qOffset (TODO confirm)
+    return Concat3dDim1TestImpl<DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
LayerTestResult<uint8_t, 3> ConcatUint8DifferentQParamsTest(
IWorkloadFactory& workloadFactory,
const IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
diff --git a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
index 4ce9d29..167a547 100644
--- a/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/ConcatTestImpl.hpp
@@ -7,8 +7,9 @@
#include "LayerTestResult.hpp"
-#include <ResolveType.hpp>
+#include <BFloat16.hpp>
#include <Half.hpp>
+#include <ResolveType.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
@@ -23,6 +24,10 @@
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+LayerTestResult<armnn::BFloat16, 3> ConcatBFloat16Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<armnn::Half, 3> ConcatFloat16Test(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
diff --git a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
index 89cdd96..e1babd3 100644
--- a/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/Conv2dTestImpl.cpp
@@ -2791,6 +2791,12 @@
//
// Explicit template specializations
//
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
Convolution2d3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
@@ -2820,6 +2826,13 @@
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d2x3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
Convolution2d2x3x3Dilation3x3Test<armnn::DataType::QAsymmU8, armnn::DataType::Signed32>(
armnn::IWorkloadFactory&,
@@ -2834,6 +2847,13 @@
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
Convolution2d2x2Dilation2x2Padding2x2Stride3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
@@ -2855,6 +2875,13 @@
bool biasEnabled,
const armnn::DataLayout layout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2d3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2d3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory&,
@@ -2876,6 +2903,13 @@
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2d2x3x3Dilation3x3Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory&,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr&,
+ bool,
+ armnn::DataLayout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2d2x3x3Dilation3x3Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory&,
@@ -2897,6 +2931,13 @@
bool,
armnn::DataLayout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2dMult4Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2dMult4Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
@@ -2904,6 +2945,13 @@
bool biasEnabled,
const armnn::DataLayout layout);
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+DepthwiseConvolution2dMult2Test<armnn::DataType::BFloat16, armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory &workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
+ bool biasEnabled,
+ const armnn::DataLayout layout);
+
template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
DepthwiseConvolution2dMult2Test<armnn::DataType::Float32, armnn::DataType::Float32>(
armnn::IWorkloadFactory &workloadFactory,
diff --git a/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
index 69c651b..120572c 100644
--- a/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/PadTestImpl.cpp
@@ -497,3 +497,31 @@
{
return Pad4dTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager, 0.0f, 0);
}
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{ // BFloat16 variant of the 2D pad test; 0.0f, 0 presumably qScale/qOffset (TODO confirm)
+    return Pad2dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dCustomPaddingTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{ // BFloat16 2D pad test with an extra 1.0f argument, presumably the custom pad value (TODO confirm)
+    return Pad2dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0, 1.0f);
+}
+
+LayerTestResult<armnn::BFloat16, 3> PadBFloat163dTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{ // BFloat16 variant of the 3D pad test; delegates to Pad3dTestCommon
+    return Pad3dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
+
+LayerTestResult<armnn::BFloat16, 4> PadBFloat164dTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{ // BFloat16 variant of the 4D pad test; delegates to Pad4dTestCommon
+    return Pad4dTestCommon<armnn::DataType::BFloat16>(workloadFactory, memoryManager, 0.0f, 0);
+}
diff --git a/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
index bc51488..34aa6c6 100644
--- a/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/PadTestImpl.hpp
@@ -67,3 +67,19 @@
LayerTestResult<float, 4> PadFloat324dTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 2> PadBFloat162dCustomPaddingTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 3> PadBFloat163dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<armnn::BFloat16, 4> PadBFloat164dTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
diff --git a/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
index 71e1533..96d4ec8 100644
--- a/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/PermuteTestImpl.hpp
@@ -72,27 +72,31 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2,
3, 4,
5, 6,
7, 8
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 5, 2, 6,
3, 7, 4, 8
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -117,28 +121,32 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -163,28 +171,32 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33,
- });
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -209,30 +221,34 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
- {
- 1, 2, 3,
- 11, 12, 13,
- 21, 22, 23,
- 31, 32, 33,
- 41, 42, 43,
- 51, 52, 53
- });
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
+ {
+ 1, 2, 3,
+ 11, 12, 13,
+ 21, 22, 23,
+ 31, 32, 33,
+ 41, 42, 43,
+ 51, 52, 53
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
- {
- 1, 11, 21, 31, 41, 51,
- 2, 12, 22, 32, 42, 52,
- 3, 13, 23, 33, 43, 53
- });
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
+ {
+ 1, 11, 21, 31, 41, 51,
+ 2, 12, 22, 32, 42, 52,
+ 3, 13, 23, 33, 43, 53
+ },
+ qScale, qOffset);
return SimplePermuteTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
diff --git a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
index 0e0f317..5721952 100644
--- a/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/TransposeTestImpl.hpp
@@ -72,27 +72,31 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2,
3, 4,
5, 6,
7, 8
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 5, 2, 6,
3, 7, 4, 8
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -117,28 +121,32 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -163,28 +171,32 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31,
2, 12, 22, 32,
3, 13, 23, 33
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
21, 22, 23,
31, 32, 33,
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,
@@ -209,15 +221,17 @@
outputTensorInfo = armnn::TensorInfo(4, outputShape, ArmnnType);
// Set quantization parameters if the requested type is a quantized type.
+ float qScale = 0.5f;
+ int32_t qOffset = 5;
if(armnn::IsQuantizedType<T>())
{
- inputTensorInfo.SetQuantizationScale(0.5f);
- inputTensorInfo.SetQuantizationOffset(5);
- outputTensorInfo.SetQuantizationScale(0.5f);
- outputTensorInfo.SetQuantizationOffset(5);
+ inputTensorInfo.SetQuantizationScale(qScale);
+ inputTensorInfo.SetQuantizationOffset(qOffset);
+ outputTensorInfo.SetQuantizationScale(qScale);
+ outputTensorInfo.SetQuantizationOffset(qOffset);
}
- std::vector<T> input = std::vector<T>(
+ std::vector<T> input = armnnUtils::QuantizedVector<T>(
{
1, 2, 3,
11, 12, 13,
@@ -225,14 +239,16 @@
31, 32, 33,
41, 42, 43,
51, 52, 53
- });
+ },
+ qScale, qOffset);
- std::vector<T> outputExpected = std::vector<T>(
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>(
{
1, 11, 21, 31, 41, 51,
2, 12, 22, 32, 42, 52,
3, 13, 23, 33, 43, 53
- });
+ },
+ qScale, qOffset);
return SimpleTransposeTestImpl<T>(workloadFactory, memoryManager,
descriptor, inputTensorInfo,