COMPMID-2608: Enable quantization with multiplier greater than 1 on NEON

Change-Id: Ib2b0c9ac88fc2b645f478c9981f71ee28f2c77fd
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2425
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h
index 67adcef..c09a7d9 100644
--- a/arm_compute/core/NEON/NEAsymm.h
+++ b/arm_compute/core/NEON/NEAsymm.h
@@ -88,17 +88,32 @@
 {
     const static int32x4_t zero_s32 = vdupq_n_s32(0);
 
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
-    in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
-    in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
-    in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+    if(result_shift < 0)
+    {
+        in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << (-result_shift)));
+        in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << (-result_shift)));
+        in_s32.val[2] = vmulq_n_s32(in_s32.val[2], (1 << (-result_shift)));
+        in_s32.val[3] = vmulq_n_s32(in_s32.val[3], (1 << (-result_shift)));
 
-    // Round to the nearest division by a power-of-two using result_shift_s32
-    in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
-    in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
-    in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
-    in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+        in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+        in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+        in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+        in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+
+        // Round to the nearest division by a power-of-two using result_shift_s32
+        in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+        in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+        in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
+        in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+    }
 
     // Add the offset terms
     in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -154,17 +169,32 @@
                                 int8x16_t    min_s8,
                                 int8x16_t    max_s8)
 {
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
-    in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
-    in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
-    in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+    if(result_shift < 0)
+    {
+        in_s32.val[0] = vmulq_n_s32(in_s32.val[0], (1 << (-result_shift)));
+        in_s32.val[1] = vmulq_n_s32(in_s32.val[1], (1 << (-result_shift)));
+        in_s32.val[2] = vmulq_n_s32(in_s32.val[2], (1 << (-result_shift)));
+        in_s32.val[3] = vmulq_n_s32(in_s32.val[3], (1 << (-result_shift)));
 
-    // Round to the nearest division by a power-of-two using result_shift_s32
-    in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
-    in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
-    in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
-    in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+        in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+        in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        in_s32.val[0] = vqrdmulhq_n_s32(in_s32.val[0], result_fixedpoint_multiplier);
+        in_s32.val[1] = vqrdmulhq_n_s32(in_s32.val[1], result_fixedpoint_multiplier);
+        in_s32.val[2] = vqrdmulhq_n_s32(in_s32.val[2], result_fixedpoint_multiplier);
+        in_s32.val[3] = vqrdmulhq_n_s32(in_s32.val[3], result_fixedpoint_multiplier);
+
+        // Round to the nearest division by a power-of-two using result_shift_s32
+        in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift);
+        in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift);
+        in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift);
+        in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift);
+    }
 
     // Add the offset terms
     in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -214,17 +244,54 @@
                                             const int8x16_t   &min_s8,
                                             const int8x16_t   &max_s8)
 {
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_s32.val[0] = vqrdmulhq_s32(in_s32.val[0], result_fixedpoint_multiplier.val[0]);
-    in_s32.val[1] = vqrdmulhq_s32(in_s32.val[1], result_fixedpoint_multiplier.val[1]);
-    in_s32.val[2] = vqrdmulhq_s32(in_s32.val[2], result_fixedpoint_multiplier.val[2]);
-    in_s32.val[3] = vqrdmulhq_s32(in_s32.val[3], result_fixedpoint_multiplier.val[3]);
+    const static int32x4_t one_s32 = vdupq_n_s32(1);
 
+    // Fixed point multiplication with vector saturating rounding doubling multiply high, using a vector of multipliers
+    int32x4x4_t res_shift_gt0 =
+    {
+        vqrdmulhq_s32(in_s32.val[0], result_fixedpoint_multiplier.val[0]),
+        vqrdmulhq_s32(in_s32.val[1], result_fixedpoint_multiplier.val[1]),
+        vqrdmulhq_s32(in_s32.val[2], result_fixedpoint_multiplier.val[2]),
+        vqrdmulhq_s32(in_s32.val[3], result_fixedpoint_multiplier.val[3]),
+    };
     // Round to the nearest division by a power-of-two using result_shift_s32
-    in_s32.val[0] = rounding_divide_by_pow2(in_s32.val[0], result_shift.val[0]);
-    in_s32.val[1] = rounding_divide_by_pow2(in_s32.val[1], result_shift.val[1]);
-    in_s32.val[2] = rounding_divide_by_pow2(in_s32.val[2], result_shift.val[2]);
-    in_s32.val[3] = rounding_divide_by_pow2(in_s32.val[3], result_shift.val[3]);
+    res_shift_gt0.val[0] = rounding_divide_by_pow2(res_shift_gt0.val[0], result_shift.val[0]);
+    res_shift_gt0.val[1] = rounding_divide_by_pow2(res_shift_gt0.val[1], result_shift.val[1]);
+    res_shift_gt0.val[2] = rounding_divide_by_pow2(res_shift_gt0.val[2], result_shift.val[2]);
+    res_shift_gt0.val[3] = rounding_divide_by_pow2(res_shift_gt0.val[3], result_shift.val[3]);
+
+    int32x4x4_t res_shift_lt0 =
+    {
+        vmulq_s32(in_s32.val[0], vshlq_s32(one_s32, vnegq_s32(result_shift.val[0]))),
+        vmulq_s32(in_s32.val[1], vshlq_s32(one_s32, vnegq_s32(result_shift.val[1]))),
+        vmulq_s32(in_s32.val[2], vshlq_s32(one_s32, vnegq_s32(result_shift.val[2]))),
+        vmulq_s32(in_s32.val[3], vshlq_s32(one_s32, vnegq_s32(result_shift.val[3]))),
+    };
+    res_shift_lt0.val[0] = vqrdmulhq_s32(res_shift_lt0.val[0], result_fixedpoint_multiplier.val[0]);
+    res_shift_lt0.val[1] = vqrdmulhq_s32(res_shift_lt0.val[1], result_fixedpoint_multiplier.val[1]);
+    res_shift_lt0.val[2] = vqrdmulhq_s32(res_shift_lt0.val[2], result_fixedpoint_multiplier.val[2]);
+    res_shift_lt0.val[3] = vqrdmulhq_s32(res_shift_lt0.val[3], result_fixedpoint_multiplier.val[3]);
+
+    // Select result depending on shift value
+    const uint32x4x4_t mask_lt0 =
+    {
+#ifdef __aarch64__
+        vcltzq_s32(result_shift.val[0]),
+        vcltzq_s32(result_shift.val[1]),
+        vcltzq_s32(result_shift.val[2]),
+        vcltzq_s32(result_shift.val[3]),
+#else  //__aarch64__
+        vcltq_s32(result_shift.val[0], vdupq_n_s32(0)),
+        vcltq_s32(result_shift.val[1], vdupq_n_s32(0)),
+        vcltq_s32(result_shift.val[2], vdupq_n_s32(0)),
+        vcltq_s32(result_shift.val[3], vdupq_n_s32(0)),
+#endif //__aarch64__
+    };
+
+    in_s32.val[0] = vbslq_s32(mask_lt0.val[0], res_shift_lt0.val[0], res_shift_gt0.val[0]);
+    in_s32.val[1] = vbslq_s32(mask_lt0.val[1], res_shift_lt0.val[1], res_shift_gt0.val[1]);
+    in_s32.val[2] = vbslq_s32(mask_lt0.val[2], res_shift_lt0.val[2], res_shift_gt0.val[2]);
+    in_s32.val[3] = vbslq_s32(mask_lt0.val[3], res_shift_lt0.val[3], res_shift_gt0.val[3]);
 
     // Add the offset terms
     in_s32.val[0] = vaddq_s32(in_s32.val[0], result_offset_after_shift_s32);
@@ -273,11 +340,17 @@
 {
     int32x4_t in_s32 = vdupq_n_s32(in_value);
 
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
-
-    // Shift value by result_shift_s32
-    in_value = rounding_divide_by_pow2(in_value, result_shift);
+    if(result_shift < 0)
+    {
+        in_value = vgetq_lane_s32(vqrdmulhq_n_s32(vmulq_n_s32(in_s32, (1 << (-result_shift))), result_fixedpoint_multiplier), 0);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
+        // Shift value by result_shift_s32
+        in_value = rounding_divide_by_pow2(in_value, result_shift);
+    }
 
     // Add the offset term
     in_value += result_offset_after_shift_s32;
@@ -312,11 +385,18 @@
 {
     int32x4_t in_s32 = vdupq_n_s32(in_value);
 
-    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
-    in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
+    if(result_shift < 0)
+    {
+        in_value = vgetq_lane_s32(vqrdmulhq_n_s32(vmulq_n_s32(in_s32, (1 << (-result_shift))), result_fixedpoint_multiplier), 0);
+    }
+    else
+    {
+        // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+        in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
 
-    // Shift value by result_shift_s32
-    in_value = rounding_divide_by_pow2(in_value, result_shift);
+        // Shift value by result_shift_s32
+        in_value = rounding_divide_by_pow2(in_value, result_shift);
+    }
 
     // Add the offset term
     in_value += result_offset_after_shift_s32;
diff --git a/arm_compute/core/utils/quantization/AsymmHelpers.h b/arm_compute/core/utils/quantization/AsymmHelpers.h
index 1bdc995..94876fb 100644
--- a/arm_compute/core/utils/quantization/AsymmHelpers.h
+++ b/arm_compute/core/utils/quantization/AsymmHelpers.h
@@ -60,7 +60,7 @@
  */
 Status calculate_quantized_multiplier_greater_than_one(float multiplier, int32_t *quantized_multiplier, int32_t *left_shift);
 
-/** Calculate quantized representation of per-channel multipliers with value less than one.
+/** Calculate quantized representation of per-channel multipliers.
  *
  * @param[in]      iq_info    Input quantization info.
  * @param[in]      wq_info    Weights quantization info.
@@ -69,10 +69,10 @@
  *
  * @return a status
  */
-Status calculate_quantized_multipliers_less_than_one(const QuantizationInfo &iq_info,
-                                                     const QuantizationInfo &wq_info,
-                                                     const QuantizationInfo &oq_info,
-                                                     GEMMLowpOutputStageInfo &stage_info);
+Status calculate_quantized_multipliers(const QuantizationInfo &iq_info,
+                                       const QuantizationInfo &wq_info,
+                                       const QuantizationInfo &oq_info,
+                                       GEMMLowpOutputStageInfo &stage_info);
 
 /** Get minimum and maximum values for the input quantized data type
  *