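COMPMID-2640: Fix performance regression for Resnet101 Int8 on NEON

The QASYMM8 path of NEPixelWiseMultiplicationKernel is no longer dispatched
through the _func_quantized function pointer. The input/output quantization
scales and offsets are broadcast into NEON vectors once, before the window
loop, and passed to the inline helper
mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt, which dequantizes, multiplies and
requantizes using explicit intrinsics. Quantized configurations that do not
set _run_optimized_qasymm8 keep using _func_quantized.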

Change-Id: I32c8b67c5ce0918cc5603807bad80952ea2fd097
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1848
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
index e2ea90a..a199a11 100644
--- a/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/NEON/kernels/NEPixelWiseMultiplicationKernel.h
@@ -127,6 +127,7 @@
     ITensor       *_output;
     float          _scale;
     int            _scale_exponent;
+    bool           _run_optimized_qasymm8;
 };
 
 /** Interface for the complex pixelwise multiplication kernel. */
diff --git a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
index 711bde3..1dab5d9 100644
--- a/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
+++ b/src/core/NEON/kernels/NEPixelWiseMultiplicationKernel.cpp
@@ -191,8 +191,8 @@
     return vreinterpretq_u16_s16(vcombine_s16(vmovn_s32(tmp_s2), vmovn_s32(tmp_s1)));
 }
 
-void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
-                                            const UniformQuantizationInfo &input1_qua_info, const UniformQuantizationInfo &input2_qua_info, const UniformQuantizationInfo &output_qua_info)
+inline void mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
+                                                       float32x4_t input1_vscale, int32x4_t input1_voffset, float32x4_t input2_vscale, int32x4_t input2_voffset, float32x4_t output_voffset, float32x4_t vinvscale)
 {
     const auto input1 = static_cast<const qasymm8_t *__restrict>(input1_ptr);
     const auto input2 = static_cast<const qasymm8_t *__restrict>(input2_ptr);
@@ -202,21 +202,40 @@
     const qasymm8x16_t input2_q = vld1q_u8(input2);
 
     // Dequantitize inputs
-    const float32x4x4_t in1_f32x4x4 = vdequantize(input1_q, input1_qua_info);
-    const float32x4x4_t in2_f32x4x4 = vdequantize(input2_q, input2_qua_info);
+    float32x4x4_t in1_f32x4x4;
+    float32x4x4_t in2_f32x4x4;
+    in1_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
+    in1_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input1_q))))), input1_voffset)), input1_vscale);
+    in1_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);
+    in1_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input1_q))))), input1_voffset)), input1_vscale);
 
-    const UniformQuantizationInfo tmp_qua_info = { output_qua_info.scale / scale, output_qua_info.offset };
+    in2_f32x4x4.val[0] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
+    in2_f32x4x4.val[1] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(input2_q))))), input2_voffset)), input2_vscale);
+    in2_f32x4x4.val[2] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);
+    in2_f32x4x4.val[3] = vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(input2_q))))), input2_voffset)), input2_vscale);
 
-    const float32x4x4_t out_f32x4x4 =
-    {
-        vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]),
-        vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]),
-        vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]),
-        vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3])
-    };
+    float32x4x4_t out_f32x4x4;
+    out_f32x4x4.val[0] = vmulq_f32(in1_f32x4x4.val[0], in2_f32x4x4.val[0]);
+    out_f32x4x4.val[1] = vmulq_f32(in1_f32x4x4.val[1], in2_f32x4x4.val[1]);
+    out_f32x4x4.val[2] = vmulq_f32(in1_f32x4x4.val[2], in2_f32x4x4.val[2]);
+    out_f32x4x4.val[3] = vmulq_f32(in1_f32x4x4.val[3], in2_f32x4x4.val[3]);
 
-    const uint8x16_t result = vquantize(out_f32x4x4, tmp_qua_info);
-    vst1q_u8(output, result);
+    int32x4x4_t rf;
+#ifdef __aarch64__
+    rf.val[0] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
+    rf.val[1] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
+    rf.val[2] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
+    rf.val[3] = vcvtnq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
+#else  //__aarch64__
+    rf.val[0] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[0], vinvscale));
+    rf.val[1] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[1], vinvscale));
+    rf.val[2] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[2], vinvscale));
+    rf.val[3] = vcvtq_s32_f32(vmlaq_f32(output_voffset, out_f32x4x4.val[3], vinvscale));
+#endif //__aarch64__
+    const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
+    const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));
+
+    vst1q_u8(output, vcombine_u8(pa, pb));
 }
 
 void mul_saturate_QSYMM16_QSYMM16_QSYMM16_n(const void *__restrict input1_ptr, const void *__restrict input2_ptr, void *__restrict output_ptr, float scale,
@@ -534,7 +553,7 @@
 } // namespace
 
 NEPixelWiseMultiplicationKernel::NEPixelWiseMultiplicationKernel()
-    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }
+    : _func_float(nullptr), _func_int(nullptr), _func_quantized(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _scale{ 0 }, _scale_exponent{ 0 }, _run_optimized_qasymm8(false)
 {
 }
 
@@ -549,14 +568,15 @@
     auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
-    _input1         = input1;
-    _input2         = input2;
-    _output         = output;
-    _scale          = scale;
-    _scale_exponent = 0;
-    _func_quantized = nullptr;
-    _func_int       = nullptr;
-    _func_float     = nullptr;
+    _input1                = input1;
+    _input2                = input2;
+    _output                = output;
+    _scale                 = scale;
+    _scale_exponent        = 0;
+    _func_quantized        = nullptr;
+    _func_int              = nullptr;
+    _func_float            = nullptr;
+    _run_optimized_qasymm8 = false;
 
     bool is_scale_255 = false;
     // Check and validate scaling factor
@@ -582,7 +602,7 @@
 
     if(dt_input1 == DataType::QASYMM8 && dt_input2 == DataType::QASYMM8)
     {
-        _func_quantized = &mul_saturate_QASYMM8_QASYMM8_QASYMM8_n;
+        _run_optimized_qasymm8 = true;
     }
     else if(dt_input1 == DataType::QSYMM16 && dt_input2 == DataType::QSYMM16)
     {
@@ -707,14 +727,36 @@
 
     if(is_data_type_quantized(_input1->info()->data_type()))
     {
-        execute_window_loop(collapsed, [&](const Coordinates &)
+        if(_run_optimized_qasymm8)
         {
-            (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
-                               _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
-            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
-            ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
-        },
-        input1, input2, output);
+            const int32x4_t   input1_voffset = vdupq_n_s32(_input1->info()->quantization_info().uniform().offset);
+            const float32x4_t input1_vscale  = vdupq_n_f32(_input1->info()->quantization_info().uniform().scale);
+            const int32x4_t   input2_voffset = vdupq_n_s32(_input2->info()->quantization_info().uniform().offset);
+            const float32x4_t input2_vscale  = vdupq_n_f32(_input2->info()->quantization_info().uniform().scale);
+            const float32x4_t output_voffset = vdupq_n_f32(static_cast<float>(_output->info()->quantization_info().uniform().offset));
+            const float       output_scale   = _output->info()->quantization_info().uniform().scale;
+            const float32x4_t vinvscale      = vdupq_n_f32(1.f / (output_scale / _scale));
+
+            execute_window_loop(collapsed, [&](const Coordinates &)
+            {
+                mul_saturate_QASYMM8_QASYMM8_QASYMM8_n_opt(input1.ptr(), input2.ptr(), output.ptr(), _scale,
+                                                           input1_vscale, input1_voffset, input2_vscale, input2_voffset, output_voffset, vinvscale);
+                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
+                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
+            },
+            input1, input2, output);
+        }
+        else
+        {
+            execute_window_loop(collapsed, [&](const Coordinates &)
+            {
+                (*_func_quantized)(input1.ptr(), input2.ptr(), output.ptr(), _scale,
+                                   _input1->info()->quantization_info().uniform(), _input2->info()->quantization_info().uniform(), _output->info()->quantization_info().uniform());
+                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input1));
+                ARM_COMPUTE_UNUSED(collapsed.slide_window_slice_3D(slice_input2));
+            },
+            input1, input2, output);
+        }
     }
     else if(_func_int != nullptr)
     {