IVGCVSW-2971 Support QSymm16 for DetectionPostProcess workloads

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I8af45afe851a9ccbf8bce54727147fcd52ac9a1f
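
Note: the test helpers below are now templated over the quantized data type. As an
illustrative sketch only (not part of this patch), a backend test suite could exercise
the new QSymm16 path roughly as follows; RefWorkloadFactory and the header paths are
assumptions:

    // Illustrative only: run the new quantized DetectionPostProcess tests
    // against a backend workload factory (RefWorkloadFactory assumed here).
    #include <backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp>
    #include <reference/RefWorkloadFactory.hpp>

    void RunDetectionPostProcessQSymm16Tests()
    {
        using Factory = armnn::RefWorkloadFactory;

        DetectionPostProcessRegularNmsQuantizedTest<Factory, armnn::DataType::QuantisedSymm16>();
        DetectionPostProcessFastNmsQuantizedTest<Factory, armnn::DataType::QuantisedSymm16>();
    }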
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a373f55..d0aaf1d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1459,53 +1459,63 @@
 
 void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
 {
-    ValidateNumInputs(workloadInfo, "DetectionPostProcessQueueDescriptor", 2);
+    const std::string descriptorName = "DetectionPostProcessQueueDescriptor";
+    ValidateNumInputs(workloadInfo, descriptorName, 2);
 
     if (workloadInfo.m_OutputTensorInfos.size() != 4)
     {
-        throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Requires exactly four outputs. " +
+        throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
                                        to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
     }
 
     if (m_Anchors == nullptr)
     {
-        throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Anchors tensor descriptor is missing.");
+        throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
     }
 
     const TensorInfo& boxEncodingsInfo =  workloadInfo.m_InputTensorInfos[0];
-    const TensorInfo& scoresInfo =  workloadInfo.m_InputTensorInfos[1];
-    const TensorInfo& anchorsInfo = m_Anchors->GetTensorInfo();
-    const TensorInfo& detectionBoxesInfo = workloadInfo.m_OutputTensorInfos[0];
+    const TensorInfo& scoresInfo       =  workloadInfo.m_InputTensorInfos[1];
+    const TensorInfo& anchorsInfo      = m_Anchors->GetTensorInfo();
+
+    const TensorInfo& detectionBoxesInfo   = workloadInfo.m_OutputTensorInfos[0];
     const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
-    const TensorInfo& detectionScoresInfo = workloadInfo.m_OutputTensorInfos[2];
-    const TensorInfo& numDetectionsInfo = workloadInfo.m_OutputTensorInfos[3];
+    const TensorInfo& detectionScoresInfo  = workloadInfo.m_OutputTensorInfos[2];
+    const TensorInfo& numDetectionsInfo    = workloadInfo.m_OutputTensorInfos[3];
 
-    ValidateTensorNumDimensions(boxEncodingsInfo, "DetectionPostProcessQueueDescriptor", 3, "box encodings");
-    ValidateTensorNumDimensions(scoresInfo, "DetectionPostProcessQueueDescriptor", 3, "scores");
-    ValidateTensorNumDimensions(anchorsInfo, "DetectionPostProcessQueueDescriptor", 2, "anchors");
+    ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
+    ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
+    ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
 
-    ValidateTensorNumDimensions(detectionBoxesInfo, "DetectionPostProcessQueueDescriptor", 3, "detection boxes");
-    ValidateTensorNumDimensions(detectionScoresInfo, "DetectionPostProcessQueueDescriptor", 2, "detection scores");
-    ValidateTensorNumDimensions(detectionClassesInfo, "DetectionPostProcessQueueDescriptor", 2, "detection classes");
-    ValidateTensorNumDimensions(numDetectionsInfo, "DetectionPostProcessQueueDescriptor", 1, "num detections");
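+    // Data types supported for the box encodings, scores and anchors inputs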
+    const std::vector<DataType> supportedInputTypes =
+    {
+        DataType::Float32,
+        DataType::QuantisedAsymm8,
+        DataType::QuantisedSymm16
+    };
 
-    ValidateTensorDataType(detectionBoxesInfo, DataType::Float32,
-                          "DetectionPostProcessQueueDescriptor", "detection boxes");
-    ValidateTensorDataType(detectionScoresInfo, DataType::Float32,
-                          "DetectionPostProcessQueueDescriptor", "detection scores");
-    ValidateTensorDataType(detectionClassesInfo, DataType::Float32,
-                          "DetectionPostProcessQueueDescriptor", "detection classes");
-    ValidateTensorDataType(numDetectionsInfo, DataType::Float32,
-                          "DetectionPostProcessQueueDescriptor", "num detections");
+    ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
+    ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
+    ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
+
+    ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
+    ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
+    ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
+    ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
+
+    // NOTE: Outputs are always Float32 regardless of the input type
+    ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
+    ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
+    ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
+    ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
 
     if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
     {
-        throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Intersection over union threshold "
+        throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
                                        "must be positive and less than or equal to 1.");
     }
     if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
     {
-        throw InvalidArgumentException("DetectionPostProcessQueueDescriptor: Number of classes with background "
+        throw InvalidArgumentException(descriptorName + ": Number of classes with background "
                                        "should be equal to number of classes + 1.");
     }
 }
diff --git a/src/backends/backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp b/src/backends/backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp
index 092ce26..2726fde 100644
--- a/src/backends/backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp
+++ b/src/backends/backendsCommon/test/DetectionPostProcessLayerTestImpl.hpp
@@ -15,7 +15,124 @@
 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
 #include <test/TensorHelpers.hpp>
 
-template <typename FactoryType, armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+namespace
+{
+
+using FloatData = std::vector<float>;
+using QuantData = std::pair<float, int32_t>;
+
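+// Common test inputs (tensor shapes, quantization parameters and float reference data)
+// shared by the float and quantized DetectionPostProcess tests.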
+struct TestData
+{
+    static const armnn::TensorShape s_BoxEncodingsShape;
+    static const armnn::TensorShape s_ScoresShape;
+    static const armnn::TensorShape s_AnchorsShape;
+
+    static const QuantData s_BoxEncodingsQuantData;
+    static const QuantData s_ScoresQuantData;
+    static const QuantData s_AnchorsQuantData;
+
+    static const FloatData s_BoxEncodings;
+    static const FloatData s_Scores;
+    static const FloatData s_Anchors;
+};
+
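+// Expected outputs for the regular NMS code path.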
+struct RegularNmsExpectedResults
+{
+    static const FloatData s_DetectionBoxes;
+    static const FloatData s_DetectionScores;
+    static const FloatData s_DetectionClasses;
+    static const FloatData s_NumDetections;
+};
+
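+// Expected outputs for the fast NMS code path.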
+struct FastNmsExpectedResults
+{
+    static const FloatData s_DetectionBoxes;
+    static const FloatData s_DetectionScores;
+    static const FloatData s_DetectionClasses;
+    static const FloatData s_NumDetections;
+};
+
+const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
+const armnn::TensorShape TestData::s_ScoresShape       = { 1, 6, 3 };
+const armnn::TensorShape TestData::s_AnchorsShape      = { 6, 4 };
+
+const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
+const QuantData TestData::s_ScoresQuantData       = { 0.01f, 0 };
+const QuantData TestData::s_AnchorsQuantData      = { 0.50f, 0 };
+
+const FloatData TestData::s_BoxEncodings =
+{
+    0.0f,  0.0f, 0.0f, 0.0f,
+    0.0f,  1.0f, 0.0f, 0.0f,
+    0.0f, -1.0f, 0.0f, 0.0f,
+    0.0f,  0.0f, 0.0f, 0.0f,
+    0.0f,  1.0f, 0.0f, 0.0f,
+    0.0f,  0.0f, 0.0f, 0.0f
+};
+
+const FloatData TestData::s_Scores =
+{
+    0.0f, 0.90f, 0.80f,
+    0.0f, 0.75f, 0.72f,
+    0.0f, 0.60f, 0.50f,
+    0.0f, 0.93f, 0.95f,
+    0.0f, 0.50f, 0.40f,
+    0.0f, 0.30f, 0.20f
+};
+
+const FloatData TestData::s_Anchors =
+{
+    0.5f,   0.5f, 1.0f, 1.0f,
+    0.5f,   0.5f, 1.0f, 1.0f,
+    0.5f,   0.5f, 1.0f, 1.0f,
+    0.5f,  10.5f, 1.0f, 1.0f,
+    0.5f,  10.5f, 1.0f, 1.0f,
+    0.5f, 100.5f, 1.0f, 1.0f
+};
+
+const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
+{
+    0.0f, 10.0f, 1.0f, 11.0f,
+    0.0f, 10.0f, 1.0f, 11.0f,
+    0.0f,  0.0f, 0.0f,  0.0f
+};
+
+const FloatData RegularNmsExpectedResults::s_DetectionScores =
+{
+    0.95f, 0.93f, 0.0f
+};
+
+const FloatData RegularNmsExpectedResults::s_DetectionClasses =
+{
+    1.0f, 0.0f, 0.0f
+};
+
+const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };
+
+const FloatData FastNmsExpectedResults::s_DetectionBoxes =
+{
+    0.0f,  10.0f, 1.0f,  11.0f,
+    0.0f,   0.0f, 1.0f,   1.0f,
+    0.0f, 100.0f, 1.0f, 101.0f
+};
+
+const FloatData FastNmsExpectedResults::s_DetectionScores =
+{
+    0.95f, 0.9f, 0.3f
+};
+
+const FloatData FastNmsExpectedResults::s_DetectionClasses =
+{
+    1.0f, 0.0f, 0.0f
+};
+
+const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };
+
+} // anonymous namespace
+
+template<typename FactoryType,
+         armnn::DataType ArmnnType,
+         typename T = armnn::ResolveType<ArmnnType>>
 void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
                               const armnn::TensorInfo& scoresInfo,
                               const armnn::TensorInfo& anchorsInfo,
@@ -110,254 +227,140 @@
     BOOST_TEST(CompareTensors(numDetectionsResult.output, numDetectionsResult.outputExpected));
 }
 
-inline void QuantizeData(uint8_t* quant, const float* dequant, const armnn::TensorInfo& info)
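+// Quantizes a buffer of float values into RawType (e.g. uint8_t for QuantisedAsymm8,
+// int16_t for QuantisedSymm16) using the scale and offset of the given TensorInfo.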
+template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
+void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
 {
     for (size_t i = 0; i < info.GetNumElements(); i++)
     {
-        quant[i] = armnn::Quantize<uint8_t>(dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
+        quant[i] = armnn::Quantize<RawType>(
+            dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
     }
 }
 
-template <typename FactoryType>
+template<typename FactoryType>
 void DetectionPostProcessRegularNmsFloatTest()
 {
-    armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
-    armnn::TensorInfo scoresInfo({ 1, 6, 3}, armnn::DataType::Float32);
-    armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
-
-    std::vector<float> boxEncodingsData({
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, -1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> scoresData({
-        0.0f, 0.9f, 0.8f,
-        0.0f, 0.75f, 0.72f,
-        0.0f, 0.6f, 0.5f,
-        0.0f, 0.93f, 0.95f,
-        0.0f, 0.5f, 0.4f,
-        0.0f, 0.3f, 0.2f
-    });
-    std::vector<float> anchorsData({
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 100.5f, 1.0f, 1.0f
-    });
-
-    std::vector<float> expectedDetectionBoxes({
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
-    std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
-    std::vector<float> expectedNumDetections({ 2.0f });
-
-    return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(boxEncodingsInfo,
-                                                                           scoresInfo,
-                                                                           anchorsInfo,
-                                                                           boxEncodingsData,
-                                                                           scoresData,
-                                                                           anchorsData,
-                                                                           expectedDetectionBoxes,
-                                                                           expectedDetectionClasses,
-                                                                           expectedDetectionScores,
-                                                                           expectedNumDetections,
-                                                                           true);
+    return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
+        armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
+        armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
+        armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
+        TestData::s_BoxEncodings,
+        TestData::s_Scores,
+        TestData::s_Anchors,
+        RegularNmsExpectedResults::s_DetectionBoxes,
+        RegularNmsExpectedResults::s_DetectionClasses,
+        RegularNmsExpectedResults::s_DetectionScores,
+        RegularNmsExpectedResults::s_NumDetections,
+        true);
 }
 
-template <typename FactoryType>
-void DetectionPostProcessRegularNmsUint8Test()
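+// Regular NMS test with inputs quantized to QuantizedType; expected results stay in Float32.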
+template<typename FactoryType,
+         armnn::DataType QuantizedType,
+         typename RawType = armnn::ResolveType<QuantizedType>>
+void DetectionPostProcessRegularNmsQuantizedTest()
 {
-    armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::QuantisedAsymm8);
-    armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::QuantisedAsymm8);
-    armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::QuantisedAsymm8);
+    armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
+    armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
+    armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
 
-    boxEncodingsInfo.SetQuantizationScale(1.0f);
-    boxEncodingsInfo.SetQuantizationOffset(1);
-    scoresInfo.SetQuantizationScale(0.01f);
-    scoresInfo.SetQuantizationOffset(0);
-    anchorsInfo.SetQuantizationScale(0.5f);
-    anchorsInfo.SetQuantizationOffset(0);
+    boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
+    boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
 
-    std::vector<float> boxEncodings({
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, -1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> scores({
-        0.0f, 0.9f, 0.8f,
-        0.0f, 0.75f, 0.72f,
-        0.0f, 0.6f, 0.5f,
-        0.0f, 0.93f, 0.95f,
-        0.0f, 0.5f, 0.4f,
-        0.0f, 0.3f, 0.2f
-    });
-    std::vector<float> anchors({
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 100.5f, 1.0f, 1.0f
-    });
+    scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
+    scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
 
-    std::vector<uint8_t> boxEncodingsData(boxEncodings.size(), 0);
-    std::vector<uint8_t> scoresData(scores.size(), 0);
-    std::vector<uint8_t> anchorsData(anchors.size(), 0);
-    QuantizeData(boxEncodingsData.data(), boxEncodings.data(), boxEncodingsInfo);
-    QuantizeData(scoresData.data(), scores.data(), scoresInfo);
-    QuantizeData(anchorsData.data(), anchors.data(), anchorsInfo);
+    anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
+    anchorsInfo.SetQuantizationOffset(TestData::s_AnchorsQuantData.second);
 
-    std::vector<float> expectedDetectionBoxes({
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
-    std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
-    std::vector<float> expectedNumDetections({ 2.0f });
+    std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
+    QuantizeData<QuantizedType>(boxEncodingsData.data(),
+                                TestData::s_BoxEncodings.data(),
+                                boxEncodingsInfo);
 
-    return DetectionPostProcessImpl<FactoryType, armnn::DataType::QuantisedAsymm8>(boxEncodingsInfo,
-                                                                                   scoresInfo,
-                                                                                   anchorsInfo,
-                                                                                   boxEncodingsData,
-                                                                                   scoresData,
-                                                                                   anchorsData,
-                                                                                   expectedDetectionBoxes,
-                                                                                   expectedDetectionClasses,
-                                                                                   expectedDetectionScores,
-                                                                                   expectedNumDetections,
-                                                                                   true);
+    std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
+    QuantizeData<QuantizedType>(scoresData.data(),
+                                TestData::s_Scores.data(),
+                                scoresInfo);
+
+    std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
+    QuantizeData<QuantizedType>(anchorsData.data(),
+                                TestData::s_Anchors.data(),
+                                anchorsInfo);
+
+    return DetectionPostProcessImpl<FactoryType, QuantizedType>(
+        boxEncodingsInfo,
+        scoresInfo,
+        anchorsInfo,
+        boxEncodingsData,
+        scoresData,
+        anchorsData,
+        RegularNmsExpectedResults::s_DetectionBoxes,
+        RegularNmsExpectedResults::s_DetectionClasses,
+        RegularNmsExpectedResults::s_DetectionScores,
+        RegularNmsExpectedResults::s_NumDetections,
+        true);
 }
 
-template <typename FactoryType>
+template<typename FactoryType>
 void DetectionPostProcessFastNmsFloatTest()
 {
-    armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
-    armnn::TensorInfo scoresInfo({ 1, 6, 3}, armnn::DataType::Float32);
-    armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
-
-    std::vector<float> boxEncodingsData({
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, -1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> scoresData({
-        0.0f, 0.9f, 0.8f,
-        0.0f, 0.75f, 0.72f,
-        0.0f, 0.6f, 0.5f,
-        0.0f, 0.93f, 0.95f,
-        0.0f, 0.5f, 0.4f,
-        0.0f, 0.3f, 0.2f
-    });
-    std::vector<float> anchorsData({
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 100.5f, 1.0f, 1.0f
-    });
-
-    std::vector<float> expectedDetectionBoxes({
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 0.0f, 1.0f, 1.0f,
-        0.0f, 100.0f, 1.0f, 101.0f
-    });
-    std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
-    std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
-    std::vector<float> expectedNumDetections({ 3.0f });
-
-    return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(boxEncodingsInfo,
-                                                                           scoresInfo,
-                                                                           anchorsInfo,
-                                                                           boxEncodingsData,
-                                                                           scoresData,
-                                                                           anchorsData,
-                                                                           expectedDetectionBoxes,
-                                                                           expectedDetectionClasses,
-                                                                           expectedDetectionScores,
-                                                                           expectedNumDetections,
-                                                                           false);
+    return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
+        armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
+        armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
+        armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
+        TestData::s_BoxEncodings,
+        TestData::s_Scores,
+        TestData::s_Anchors,
+        FastNmsExpectedResults::s_DetectionBoxes,
+        FastNmsExpectedResults::s_DetectionClasses,
+        FastNmsExpectedResults::s_DetectionScores,
+        FastNmsExpectedResults::s_NumDetections,
+        false);
 }
 
-template <typename FactoryType>
-void DetectionPostProcessFastNmsUint8Test()
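+// Fast NMS test with inputs quantized to QuantizedType; expected results stay in Float32.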
+template<typename FactoryType,
+         armnn::DataType QuantizedType,
+         typename RawType = armnn::ResolveType<QuantizedType>>
+void DetectionPostProcessFastNmsQuantizedTest()
 {
-    armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::QuantisedAsymm8);
-    armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::QuantisedAsymm8);
-    armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::QuantisedAsymm8);
+    armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
+    armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
+    armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
 
-    boxEncodingsInfo.SetQuantizationScale(1.0f);
-    boxEncodingsInfo.SetQuantizationOffset(1);
-    scoresInfo.SetQuantizationScale(0.01f);
-    scoresInfo.SetQuantizationOffset(0);
-    anchorsInfo.SetQuantizationScale(0.5f);
-    anchorsInfo.SetQuantizationOffset(0);
+    boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
+    boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
 
-    std::vector<float> boxEncodings({
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, -1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f,
-        0.0f, 1.0f, 0.0f, 0.0f,
-        0.0f, 0.0f, 0.0f, 0.0f
-    });
-    std::vector<float> scores({
-        0.0f, 0.9f, 0.8f,
-        0.0f, 0.75f, 0.72f,
-        0.0f, 0.6f, 0.5f,
-        0.0f, 0.93f, 0.95f,
-        0.0f, 0.5f, 0.4f,
-        0.0f, 0.3f, 0.2f
-    });
-    std::vector<float> anchors({
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 0.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 10.5f, 1.0f, 1.0f,
-        0.5f, 100.5f, 1.0f, 1.0f
-    });
+    scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
+    scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
 
-    std::vector<uint8_t> boxEncodingsData(boxEncodings.size(), 0);
-    std::vector<uint8_t> scoresData(scores.size(), 0);
-    std::vector<uint8_t> anchorsData(anchors.size(), 0);
-    QuantizeData(boxEncodingsData.data(), boxEncodings.data(), boxEncodingsInfo);
-    QuantizeData(scoresData.data(), scores.data(), scoresInfo);
-    QuantizeData(anchorsData.data(), anchors.data(), anchorsInfo);
+    anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
+    anchorsInfo.SetQuantizationOffset(TestData::s_AnchorsQuantData.second);
 
-    std::vector<float> expectedDetectionBoxes({
-        0.0f, 10.0f, 1.0f, 11.0f,
-        0.0f, 0.0f, 1.0f, 1.0f,
-        0.0f, 100.0f, 1.0f, 101.0f
-    });
-    std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
-    std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
-    std::vector<float> expectedNumDetections({ 3.0f });
+    std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
+    QuantizeData<QuantizedType>(boxEncodingsData.data(),
+                                TestData::s_BoxEncodings.data(),
+                                boxEncodingsInfo);
 
-    return DetectionPostProcessImpl<FactoryType, armnn::DataType::QuantisedAsymm8>(boxEncodingsInfo,
-                                                                                   scoresInfo,
-                                                                                   anchorsInfo,
-                                                                                   boxEncodingsData,
-                                                                                   scoresData,
-                                                                                   anchorsData,
-                                                                                   expectedDetectionBoxes,
-                                                                                   expectedDetectionClasses,
-                                                                                   expectedDetectionScores,
-                                                                                   expectedNumDetections,
-                                                                                   false);
-}
+    std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
+    QuantizeData<QuantizedType>(scoresData.data(),
+                                TestData::s_Scores.data(),
+                                scoresInfo);
+
+    std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
+    QuantizeData<QuantizedType>(anchorsData.data(),
+                                TestData::s_Anchors.data(),
+                                anchorsInfo);
+
+    return DetectionPostProcessImpl<FactoryType, QuantizedType>(
+        boxEncodingsInfo,
+        scoresInfo,
+        anchorsInfo,
+        boxEncodingsData,
+        scoresData,
+        anchorsData,
+        FastNmsExpectedResults::s_DetectionBoxes,
+        FastNmsExpectedResults::s_DetectionClasses,
+        FastNmsExpectedResults::s_DetectionScores,
+        FastNmsExpectedResults::s_NumDetections,
+        false);
+}
\ No newline at end of file