MLCE-229: Support for negative shifts in asm kernels
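
The arm_gemm requantization helper (requantize_block_32_int) previously
assumed the requantization exponent was always a right shift. This patch
adds a do_left_shift template parameter so that per-layer and per-channel
left shifts (i.e. "negative" right-shift exponents) are supported: when a
left shift is present it is applied first with vrshlq_s32, followed by the
existing saturating doubling multiply (vqrdmulhq_s32) and rounding right
shift. The single per_layer_shift / per_channel_shifts inputs are split
into *_right_shift(s) and *_left_shift(s) accordingly. requantize_block_32
selects the specialisation from qp.per_channel_left_shifts (per-channel
case) or qp.per_layer_left_shift > 0 (per-layer case), so the existing
code path is untouched when no left shift is needed.

As a rough scalar model of the path this enables (illustrative only; the
kernel operates on int32x4_t lanes, and the names below are not taken from
the source):

    #include <algorithm>
    #include <cstdint>

    // Requantize one int32 accumulator.  'shift' follows the vrshlq_s32
    // convention: positive = left shift, negative = right shift.  The
    // kernel keeps the two parts in separate fields; a single signed
    // exponent is split here for brevity.
    static int8_t requantize_one(int32_t acc, int32_t mul, int32_t shift,
                                 int32_t c_offset, int32_t minval, int32_t maxval)
    {
        const int32_t left_shift  = std::max<int32_t>(shift, 0);
        const int32_t right_shift = std::min<int32_t>(shift, 0);

        // Left shift first, so the fixed-point multiply sees the scaled value.
        acc = (int32_t)((uint32_t)acc << left_shift);

        // Rounding doubling multiply returning the high half, i.e. what
        // vqrdmulhq_s32 does per lane (saturation of the INT32_MIN *
        // INT32_MIN corner case omitted).
        const int64_t prod = (int64_t)acc * (int64_t)mul;
        acc = (int32_t)((prod + (int64_t(1) << 30)) >> 31);

        // Rounding right shift (vrshlq_s32 with a negative shift amount).
        if (right_shift < 0) {
            const int32_t s = -right_shift;
            acc = (acc + (1 << (s - 1))) >> s;
        }

        // Add the output offset and clamp to the output range.
        acc += c_offset;
        return (int8_t)std::max(minval, std::min(maxval, acc));
    }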

Change-Id: I2c5e98aae7698963f106d7423df0e65cd00ee2a9
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3710
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/quantized.cpp b/src/core/NEON/kernels/arm_gemm/quantized.cpp
index e50dca7..201bd9d 100644
--- a/src/core/NEON/kernels/arm_gemm/quantized.cpp
+++ b/src/core/NEON/kernels/arm_gemm/quantized.cpp
@@ -55,15 +55,16 @@
  * column is set up in any case (and it is hoped that the compiler can elide
  * the needless movs in the per-layer case).
  */
-template<bool do_shift_correction, bool per_channel>
+template<bool do_shift_correction, bool per_channel, bool do_left_shift>
 void requantize_block_32_int(const Requantize32 &qp, unsigned int width, unsigned int height,
                              const int32_t *input, unsigned int in_stride, int8_t *output, unsigned int out_stride,
                              const int32_t *row_bias, const int32_t *col_bias, const unsigned int start_col) {
-    const int32x4_t v_mul      = vdupq_n_s32(qp.per_layer_mul);
-    const int32x4_t v_shift    = vdupq_n_s32(qp.per_layer_shift);
-    const int32x4_t v_minval   = vdupq_n_s32(qp.minval);
-    const int32x4_t v_maxval   = vdupq_n_s32(qp.maxval);
-    const int32x4_t v_c_offset = vdupq_n_s32(qp.c_offset);
+    const int32x4_t v_mul          = vdupq_n_s32(qp.per_layer_mul);
+    const int32x4_t v_right_shift  = vdupq_n_s32(qp.per_layer_right_shift);
+    const int32x4_t v_left_shift   = vdupq_n_s32(qp.per_layer_left_shift);
+    const int32x4_t v_minval       = vdupq_n_s32(qp.minval);
+    const int32x4_t v_maxval       = vdupq_n_s32(qp.maxval);
+    const int32x4_t v_c_offset     = vdupq_n_s32(qp.c_offset);
 
     /* To make sure we have plenty of accumulators, compute two rows at a
      * time.  If the number of rows is odd, compute the bottom row twice to
@@ -77,8 +78,9 @@
         unsigned int odds=(width % 4);
 
         const int32_t *colptr = col_bias;
-        const int32_t *perch_mul_ptr   = qp.per_channel_muls + start_col;
-        const int32_t *perch_shift_ptr = qp.per_channel_shifts + start_col;
+        const int32_t *perch_mul_ptr    = qp.per_channel_muls + start_col;
+        const int32_t *perch_shift_ptr  = qp.per_channel_right_shifts + start_col;
+        const int32_t *perch_shiftl_ptr = qp.per_channel_left_shifts + start_col;
 
         const int32_t *in_ptr = input + (row * in_stride);
         int8_t *out_ptr = output + (row * out_stride);
@@ -112,6 +114,11 @@
             int32x4_t v_shf2;
             int32x4_t v_shf3;
 
+            int32x4_t v_shf0l;
+            int32x4_t v_shf1l;
+            int32x4_t v_shf2l;
+            int32x4_t v_shf3l;
+
             if (per_channel) {
                 v_mul0 = vld1q_s32(perch_mul_ptr);
                 v_mul1 = vld1q_s32(perch_mul_ptr + 4);
@@ -124,9 +131,17 @@
                 v_shf2 = vld1q_s32(perch_shift_ptr + 8);
                 v_shf3 = vld1q_s32(perch_shift_ptr + 12);
                 perch_shift_ptr += 16;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    v_shf1l = vld1q_s32(perch_shiftl_ptr + 4);
+                    v_shf2l = vld1q_s32(perch_shiftl_ptr + 8);
+                    v_shf3l = vld1q_s32(perch_shiftl_ptr + 12);
+                }
             } else {
                 v_mul0=v_mul1=v_mul2=v_mul3=v_mul;
-                v_shf0=v_shf1=v_shf2=v_shf3=v_shift;
+                v_shf0=v_shf1=v_shf2=v_shf3=v_right_shift;
+                v_shf0l=v_shf1l=v_shf2l=v_shf3l=v_left_shift;
             }
 
             // Load column pointers
@@ -171,7 +186,22 @@
             v_in12 = vaddq_s32(v_in12, v_col2);
             v_in13 = vaddq_s32(v_in13, v_col3);
 
-            // Quantize - start with multiply
+            // Quantize
+
+            // If a left shift is needed it needs to happen first.
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+                v_in01 = vrshlq_s32(v_in01, v_shf1l);
+                v_in02 = vrshlq_s32(v_in02, v_shf2l);
+                v_in03 = vrshlq_s32(v_in03, v_shf3l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+                v_in11 = vrshlq_s32(v_in11, v_shf1l);
+                v_in12 = vrshlq_s32(v_in12, v_shf2l);
+                v_in13 = vrshlq_s32(v_in13, v_shf3l);
+            }
+
+            // Multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
             v_in01 = vqrdmulhq_s32(v_in01, v_mul1);
             v_in02 = vqrdmulhq_s32(v_in02, v_mul2);
@@ -273,6 +303,7 @@
         while (regs--) {
             int32x4_t v_mul0;
             int32x4_t v_shf0;
+            int32x4_t v_shf0l;
 
             if (per_channel) {
                 v_mul0 = vld1q_s32(perch_mul_ptr);
@@ -280,9 +311,15 @@
 
                 v_shf0 = vld1q_s32(perch_shift_ptr);
                 perch_shift_ptr += 4;
+
+                if (do_left_shift) {
+                    v_shf0l = vld1q_s32(perch_shiftl_ptr);
+                    perch_shiftl_ptr += 4;
+                }
             } else {
                 v_mul0=v_mul;
-                v_shf0=v_shift;
+                v_shf0=v_right_shift;
+                v_shf0l=v_left_shift;
             }
             // Load column pointers
             int32x4_t v_col0 = vld1q_s32(colptr);
@@ -306,7 +343,14 @@
 
             v_in10 = vaddq_s32(v_in10, v_col0);
 
-            // Quantize - start with multiply
+            // Quantize - start with (optional) left shift
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+            }
+
+            // Then multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
             v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -358,10 +402,12 @@
             int32x4_t v_in10 = vdupq_n_s32(0);
             int32x4_t v_mul0 = vdupq_n_s32(0);
             int32x4_t v_shf0 = vdupq_n_s32(0);
+            int32x4_t v_shf0l = vdupq_n_s32(0);
 
             if (!per_channel) {
                 v_mul0 = v_mul;
-                v_shf0 = v_shift;
+                v_shf0 = v_right_shift;
+                v_shf0l = v_left_shift;
             }
 
             do {
@@ -371,6 +417,9 @@
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr, v_mul0, 0);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr, v_shf0, 0);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr, v_shf0l, 0);
+                    }
                 }
                 if (odds == 1) { break; }
 
@@ -380,6 +429,9 @@
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 1, v_mul0, 1);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 1, v_shf0, 1);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 1, v_shf0l, 1);
+                    }
                 }
                 if (odds == 2) { break; }
 
@@ -389,6 +441,9 @@
                 if (per_channel) {
                     v_mul0 = vld1q_lane_s32(perch_mul_ptr + 2, v_mul0, 2);
                     v_shf0 = vld1q_lane_s32(perch_shift_ptr + 2, v_shf0, 2);
+                    if (do_left_shift) {
+                        v_shf0l = vld1q_lane_s32(perch_shiftl_ptr + 2, v_shf0l, 2);
+                    }
                 }
             } while (0);
 
@@ -402,7 +457,14 @@
 
             v_in10 = vaddq_s32(v_in10, v_col0);
 
-            // Quantize - start with multiply
+            // Quantize - start with (optional) left shift
+            if (do_left_shift) {
+                v_in00 = vrshlq_s32(v_in00, v_shf0l);
+
+                v_in10 = vrshlq_s32(v_in10, v_shf0l);
+            }
+
+            // Then multiply
             v_in00 = vqrdmulhq_s32(v_in00, v_mul0);
 
             v_in10 = vqrdmulhq_s32(v_in10, v_mul0);
@@ -464,19 +526,39 @@
                          const int32_t *row_bias, const int32_t *col_bias, unsigned int start_col) {
     if (qp.per_channel_requant) {
         if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<false, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         } else {
-            requantize_block_32_int<true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_channel_left_shifts) {
+                requantize_block_32_int<true, true, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         }
     } else {
         if (qp.minval >= qp.c_offset) {
-            requantize_block_32_int<false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<false, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<false, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         } else {
-            requantize_block_32_int<true, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
-                             reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            if (qp.per_layer_left_shift > 0) {
+                requantize_block_32_int<true, false, true>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            } else {
+                requantize_block_32_int<true, false, false>(qp, width, height, reinterpret_cast<const int32_t *>(input), in_stride,
+                                 reinterpret_cast<int8_t *>(output), out_stride, row_bias, col_bias, start_col);
+            }
         }
     }
 }