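COMPMID-2966 Add support for QASYMM8_SIGNED in NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel

NEGEMMLowpOutputStage.cpp now uses NEGEMMLowpQuantizeDownInt32ScaleKernel in
place of NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel and passes its
parameters through a GEMMLowpOutputStageInfo pointer. For the QUANTIZE_DOWN
stage type, NEGEMMLowpOutputStage selects this kernel for both QASYMM8 and
QASYMM8_SIGNED outputs.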

Signed-off-by: Luca Foschiani <luca.foschiani@arm.com>
Change-Id: Ia8692f8fda16fa3b73f343e4b5b1b55e14403225
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2750
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
index 42d2ffc..43ca7b3 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
@@ -24,10 +24,10 @@
 #include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
 
 #include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ScaleKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
 #include "arm_compute/core/Validate.h"
 #include "support/MemorySupport.h"
 
@@ -35,14 +35,25 @@
 {
 void NEGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max)
 {
-    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
-    k->configure(input, bias, output, result_offset, result_mult_int, result_shift, min, max);
+    GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo();
+    info.gemmlowp_offset         = result_offset;
+    info.gemmlowp_multiplier     = result_mult_int;
+    info.gemmlowp_shift          = result_shift;
+    info.gemmlowp_min_bound      = min;
+    info.gemmlowp_max_bound      = max;
+
+    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>();
+    k->configure(input, bias, output, &info);
     _kernel = std::move(k);
 }
 
 Status NEGEMMLowpQuantizeDownInt32ToUint8Scale::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
 {
-    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max);
+    GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo();
+    info.gemmlowp_min_bound      = min;
+    info.gemmlowp_max_bound      = max;
+
+    return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info);
 }
 
 void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift,
@@ -89,53 +100,63 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
     ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info));
 
-    if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+    switch(info.type)
     {
-        switch(output->info()->data_type())
+        case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
         {
-            case DataType::QASYMM8:
+            switch(info.output_data_type)
             {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
-                k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-                _kernel = std::move(k);
-                break;
+                case DataType::QASYMM8:
+                {
+                    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
+                    k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                    _kernel = std::move(k);
+                    break;
+                }
+                case DataType::QASYMM8_SIGNED:
+                {
+                    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
+                    k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                    _kernel = std::move(k);
+                    break;
+                }
+                case DataType::QSYMM16:
+                {
+                    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
+                    k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                    _kernel = std::move(k);
+                    break;
+                }
+                default:
+                {
+                    ARM_COMPUTE_ERROR("Unsupported output data type.");
+                    break;
+                }
             }
-            default:
-                ARM_COMPUTE_ERROR("Unsupported output data type.");
+            break;
         }
-    }
-    else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
-    {
-        switch(output->info()->data_type())
+        case GEMMLowpOutputStageType::QUANTIZE_DOWN:
         {
-            case DataType::QASYMM8:
+            switch(info.output_data_type)
             {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
-                k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-                _kernel = std::move(k);
-                break;
+                case DataType::QASYMM8:
+                case DataType::QASYMM8_SIGNED:
+                {
+                    auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>();
+                    k->configure(input, bias, output, &info);
+                    _kernel = std::move(k);
+                    break;
+                }
+                default:
+                {
+                    ARM_COMPUTE_ERROR("Unsupported output data type.");
+                    break;
+                }
             }
-            case DataType::QASYMM8_SIGNED:
-            {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
-                k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-                _kernel = std::move(k);
-                break;
-            }
-            case DataType::QSYMM16:
-            {
-                auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
-                k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-                _kernel = std::move(k);
-                break;
-            }
-            default:
-                ARM_COMPUTE_ERROR("Unsupported output data type.");
+            break;
         }
-    }
-    else
-    {
-        ARM_COMPUTE_ERROR("Unsupported output stage quantization type.");
+        default:
+            ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type.");
     }
 }
 
@@ -147,29 +168,35 @@
 
     ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT));
 
-    if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+    switch(info.type)
     {
-        switch(output->data_type())
+        case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
         {
-            case DataType::QASYMM8:
-                return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-            default:
-                return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+            switch(output->data_type())
+            {
+                case DataType::QASYMM8:
+                    return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                case DataType::QASYMM8_SIGNED:
+                    return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                case DataType::QSYMM16:
+                    return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+                default:
+                    return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+            }
         }
-    }
-    else
-    {
-        switch(output->data_type())
+        case GEMMLowpOutputStageType::QUANTIZE_DOWN:
         {
-            case DataType::QASYMM8:
-                return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-            case DataType::QASYMM8_SIGNED:
-                return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-            case DataType::QSYMM16:
-                return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
-            default:
-                return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+            switch(output->data_type())
+            {
+                case DataType::QASYMM8:
+                case DataType::QASYMM8_SIGNED:
+                    return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info);
+                default:
+                    return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+            }
         }
+        default:
+            return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type.");
     }
 }
 } // namespace arm_compute