COMPMID-2060: Support different qinfo in PoolingLayer

The CL and NEON back ends now support different quantization info (qinfo) for the input and output tensors.

Change-Id: I638d5f258ab2f99b40659601b4c5398d2c34c43b
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/927
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
index d00a4af..308fad5 100644
--- a/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEPoolingLayerKernel.cpp
@@ -138,7 +138,6 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, output);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON((output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
                                     || (output->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
     }
@@ -640,6 +639,15 @@
             }
         }
 
+        const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+        const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+        if(input_qinfo != output_qinfo)
+        {
+            const auto requantized_output = vquantize(vdequantize(vcombine_u8(lower_res, upper_res), input_qinfo), output_qinfo);
+            lower_res                     = vget_low_u8(requantized_output);
+            upper_res                     = vget_high_u8(requantized_output);
+        }
+
         // Store result
         if(pool_stride_x == 1)
         {
@@ -1641,6 +1649,11 @@
         }
 
         // Store result
+        const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+        const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+        res                                  = (input_qinfo != output_qinfo) ? sqcvt_qasymm8_f32(scvt_f32_qasymm8(res, input_qinfo.scale, input_qinfo.offset), output_qinfo.scale,
+                                                                                                 output_qinfo.offset) :
+                                               res;
         *(reinterpret_cast<uint8_t *>(output.ptr())) = res;
     },
     input, output);
@@ -1663,7 +1676,9 @@
     const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
     const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
 
-    const float32x4_t half_scale_v = vdupq_n_f32(0.5f);
+    const float32x4_t       half_scale_v = vdupq_n_f32(0.5f);
+    const QuantizationInfo &input_qinfo  = _input->info()->quantization_info();
+    const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
 
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -1713,6 +1728,12 @@
 
             uint8x8_t res1 = vmovn_u16(vcombine_u16(vmovn_u32(vres1), vmovn_u32(vres2)));
             uint8x8_t res2 = vmovn_u16(vcombine_u16(vmovn_u32(vres3), vmovn_u32(vres4)));
+            if(input_qinfo != output_qinfo)
+            {
+                const auto requantized_output = vquantize(vdequantize(vcombine_u8(res1, res2), input_qinfo), output_qinfo);
+                res1                          = vget_low_u8(requantized_output);
+                res2                          = vget_high_u8(requantized_output);
+            }
 
             // Store result
             vst1_u8(output.ptr(), res1);
@@ -1733,7 +1754,7 @@
             }
 
             // Store result
-            vst1q_u8(output.ptr(), vres);
+            vst1q_u8(output.ptr(), (input_qinfo != output_qinfo) ? vquantize(vdequantize(vres, input_qinfo), output_qinfo) : vres);
         }
     },
     input, output);