COMPMID-2808: Add support for QASYMM8_SIGNED in NEROIAlignLayer

Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: Id4f4c96e1823a4b27886fee9baf70847172e619c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4335
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp b/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
index c48cda8..e937dad 100644
--- a/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEROIAlignLayerKernel.cpp
@@ -47,7 +47,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, rois, output);
     ARM_COMPUTE_RETURN_ERROR_ON(rois->dimension(0) != 5);
     ARM_COMPUTE_RETURN_ERROR_ON(rois->num_dimensions() > 2);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32, DataType::F16);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F32, DataType::F16);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(input, DataLayout::NHWC, DataLayout::NCHW);
     ARM_COMPUTE_RETURN_ERROR_ON((pool_info.pooled_width() == 0) || (pool_info.pooled_height() == 0));
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
@@ -59,7 +59,7 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(compute_roi_align_shape(*input, *rois, pool_info), output->tensor_shape());
     }
 
-    if(input->data_type() == DataType::QASYMM8)
+    if(input->data_type() == DataType::QASYMM8 || input->data_type() == DataType::QASYMM8_SIGNED)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(rois, 1, DataType::QASYMM16);
 
@@ -116,7 +116,7 @@
 }
 
 /** Average pooling over an aligned window */
-template <typename input_data_type, DataLayout data_layout>
+template <typename input_data_type>
 inline input_data_type roi_align_1x1(const ITensor *input,
                                      unsigned int   roi_batch,
                                      float          region_start_x,
@@ -135,7 +135,8 @@
     }
     else
     {
-        float avg = 0;
+        const DataLayout data_layout = input->info()->data_layout();
+        float            avg         = 0;
         // Iterate through the aligned pooling region
         for(int iy = 0; iy < grid_size_y; ++iy)
         {
@@ -185,7 +186,7 @@
 }
 
 /** Average pooling over an aligned window */
-template <typename input_data_type, DataLayout data_layout>
+template <typename input_data_type>
 inline input_data_type roi_align_1x1_qasymm8(const ITensor          *input,
                                              unsigned int            roi_batch,
                                              float                   region_start_x,
@@ -205,8 +206,11 @@
     }
     else
     {
-        float                         avg         = 0;
-        const UniformQuantizationInfo input_qinfo = input->info()->quantization_info().uniform();
+        float                         avg              = 0;
+        const UniformQuantizationInfo input_qinfo      = input->info()->quantization_info().uniform();
+        const bool                    is_qasymm_signed = is_data_type_quantized_asymmetric_signed(input->info()->data_type());
+        const DataLayout              data_layout      = input->info()->data_layout();
+
         // Iterate through the aligned pooling region
         for(int iy = 0; iy < grid_size_y; ++iy)
         {
@@ -234,26 +238,57 @@
 
                 if(data_layout == DataLayout::NCHW)
                 {
-                    float data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
-                    float data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
-                    float data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
-                    float data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
-                    avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    if(is_qasymm_signed)
+                    {
+                        float data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
+                        float data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
+                        float data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
+                        float data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
+                        avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    }
+                    else
+                    {
+                        float data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_low, pz, roi_batch))), input_qinfo);
+                        float data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_low, pz, roi_batch))), input_qinfo);
+                        float data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_low, y_high, pz, roi_batch))), input_qinfo);
+                        float data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(x_high, y_high, pz, roi_batch))), input_qinfo);
+                        avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    }
                 }
                 else
                 {
-                    const auto data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
-                    const auto data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
-                    const auto data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
-                    const auto data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
-                    avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    if(is_qasymm_signed)
+                    {
+                        const auto data1 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
+                        const auto data2 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
+                        const auto data3 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
+                        const auto data4 = dequantize_qasymm8_signed(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
+                        avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    }
+                    else
+                    {
+                        const auto data1 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_low, roi_batch))), input_qinfo);
+                        const auto data2 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_low, roi_batch))), input_qinfo);
+                        const auto data3 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_low, y_high, roi_batch))), input_qinfo);
+                        const auto data4 = dequantize_qasymm8(*reinterpret_cast<const input_data_type *>(input->ptr_to_element(Coordinates(pz, x_high, y_high, roi_batch))), input_qinfo);
+                        avg += w1 * data1 + w2 * data2 + w3 * data3 + w4 * data4;
+                    }
                 }
             }
         }
 
         avg /= grid_size_x * grid_size_y;
 
-        return quantize_qasymm8(avg, out_qinfo);
+        input_data_type res = 0;
+        if(is_qasymm_signed)
+        {
+            res = quantize_qasymm8_signed(avg, out_qinfo);
+        }
+        else
+        {
+            res = quantize_qasymm8(avg, out_qinfo);
+        }
+        return res;
     }
 }
 
@@ -265,52 +300,30 @@
 
 void NEROIAlignLayerKernel::run(const Window &window, const ThreadInfo &info)
 {
-    if(_input->info()->data_layout() == DataLayout::NCHW)
+    const DataLayout data_layout = _input->info()->data_layout();
+    if(data_layout == DataLayout::NCHW || data_layout == DataLayout::NHWC)
     {
         switch(_input->info()->data_type())
         {
             case DataType::QASYMM8:
             {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, uint8_t, uint16_t>(window, info);
+                NEROIAlignLayerKernel::internal_run<uint8_t, uint16_t>(window, info);
+                break;
+            }
+            case DataType::QASYMM8_SIGNED:
+            {
+                NEROIAlignLayerKernel::internal_run<int8_t, uint16_t>(window, info);
                 break;
             }
             case DataType::F32:
             {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, float>(window, info);
+                NEROIAlignLayerKernel::internal_run<float>(window, info);
                 break;
             }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F16:
             {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NCHW, float16_t>(window, info);
-                break;
-            }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            default:
-            {
-                ARM_COMPUTE_ERROR("DataType not supported");
-                break;
-            }
-        }
-    }
-    else if(_input->info()->data_layout() == DataLayout::NHWC)
-    {
-        switch(_input->info()->data_type())
-        {
-            case DataType::QASYMM8:
-            {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, uint8_t, uint16_t>(window, info);
-                break;
-            }
-            case DataType::F32:
-            {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, float>(window, info);
-                break;
-            }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-            {
-                NEROIAlignLayerKernel::internal_run<DataLayout::NHWC, float16_t>(window, info);
+                NEROIAlignLayerKernel::internal_run<float16_t>(window, info);
                 break;
             }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -327,21 +340,22 @@
     }
 }
 
-template <DataLayout data_layout, typename input_data_type, typename roi_data_type>
+template <typename input_data_type, typename roi_data_type>
 void NEROIAlignLayerKernel::internal_run(const Window &window, const ThreadInfo &info)
 {
     ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
 
-    const size_t values_per_roi = _rois->info()->dimension(0);
+    const DataLayout data_layout    = _input->info()->data_layout();
+    const size_t     values_per_roi = _rois->info()->dimension(0);
 
     const int roi_list_start = window.x().start();
     const int roi_list_end   = window.x().end();
 
-    const unsigned int idx_width  = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::WIDTH);
-    const unsigned int idx_height = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::HEIGHT);
-    const unsigned int idx_depth  = get_data_layout_dimension_index(_input->info()->data_layout(), DataLayoutDimension::CHANNEL);
+    const unsigned int idx_width  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const unsigned int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const unsigned int idx_depth  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
     const int input_width   = _input->info()->dimension(idx_width);
     const int input_height  = _input->info()->dimension(idx_height);
@@ -397,14 +411,14 @@
                     input_data_type out_val(0);
                     if(is_qasymm)
                     {
-                        out_val = roi_align_1x1_qasymm8<input_data_type, data_layout>(
+                        out_val = roi_align_1x1_qasymm8<input_data_type>(
                                       _input, roi_batch, region_start_x, bin_size_x,
                                       roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
                                       roi_bin_grid_y, region_end_y, ch, _output->info()->quantization_info());
                     }
                     else
                     {
-                        out_val = roi_align_1x1<input_data_type, data_layout>(
+                        out_val = roi_align_1x1<input_data_type>(
                                       _input, roi_batch, region_start_x, bin_size_x,
                                       roi_bin_grid_x, region_end_x, region_start_y, bin_size_y,
                                       roi_bin_grid_y, region_end_y, ch);