Implement Quantized Matmul T/T and T/NT kernels using MMUL extension

Resolves: COMPMID-6476, COMPMID-6477
Change-Id: Ied37c269d5a108ff72f70e3ad932cf372bda5562
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10346
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl b/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
index 4ab81d1..fdfb75d 100644
--- a/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
+++ b/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
@@ -514,7 +514,9 @@
 /** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed
  *
  * Supported block configurations:
- *     TODO: Report supported M0, N0, K0
+ *  - M0 = 1, 2, 3, 4, 8, 16
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 4
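+ *    (K0 is fixed at 4 because each thread passes a 4-element vector to
+ *    arm_matrix_multiply() per call)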
  *
  * Similar to mat_mul_native_quantized_mmul_nt_nt()
  */
@@ -526,6 +528,149 @@
 #endif // defined(BIAS)
     TENSOR3D_T(dst, BUFFER))
 {
+    const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
+    // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
+    const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
+    const uint z  = get_global_id(2); // Batch
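+
+    // Under the cl_arm_matrix_multiply extension, MMUL_BLOCK_SIZE threads cooperate
+    // on a single matrix-multiply block. For the 8-bit data handled here this is
+    // assumed to be an MMUL_M0 x MMUL_N0 = 4 x 4 thread arrangement with
+    // MMUL_K0 = 16, each thread supplying K0 = 4 elements of K per call.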
+
+    // Get section coordinates
+    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
+    const uint section_y = y0;
+
+    // Get thread coordinates within an mmul block
+    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+    const uint thread_x  = thread_id % MMUL_N0;
+    const uint thread_y  = (thread_id / MMUL_N0);
+
+    // Calculate dst coordinates
+    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
+    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
+    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
+    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
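+
+    // Out-of-bound threads are clamped onto valid coordinates instead of returning
+    // early: they must keep participating in the arm_matrix_multiply() calls below
+    // and only skip the final store.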
+
+    // Starting LHS coordinates
+    const uint lhs_x = dst_y;
+    const uint lhs_y = K0 * thread_x;
+
+    // Starting RHS coordinates
+    const uint rhs_x = dst_x;
+    const uint rhs_y = K0 * thread_y;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    TILE(int, M0, N0, c);
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
+    })
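+
+    // The constant term above comes from expanding the quantized product:
+    //   sum_k (a_k - LHS_OFFSET) * (b_k - RHS_OFFSET)
+    //     = sum_k a_k * b_k
+    //       - RHS_OFFSET * sum_k a_k
+    //       - LHS_OFFSET * sum_k b_k
+    //       + K * LHS_OFFSET * RHS_OFFSET
+    // The two cross terms are applied after the main loop using a_sum and b_sum.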
+
+    // Calculate row and column sums
+    TILE(int, 1, N0, b_sum);
+    b_sum[0].v = 0;
+
+    TILE(int, 1, M0, a_sum);
+    a_sum[0].v = 0;
+
+    VEC_DATA_TYPE(DATA_TYPE, K0)
+    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
+
+    for(int k = 0; k < lhs_h; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, K0, M0, a);
+        TILE(DATA_TYPE, K0, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
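+
+        // With LHS transposed, tile a is laid out as K0 rows x M0 columns; the
+        // column gathers below therefore collect the K0 values of one output row.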
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            VEC_DATA_TYPE(DATA_TYPE, K0)
+            vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);
+
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                VEC_DATA_TYPE(DATA_TYPE, K0)
+                vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
+
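+                // All threads in the MMUL block must call arm_matrix_multiply()
+                // together; each supplies K0 values and receives the MMUL_K0-deep
+                // dot product accumulated for its (m0, n0) element.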
+                c[m0].s[n0] = arm_matrix_multiply(vec_a, vec_b, c[m0].s[n0]);
+            })
+
+#if RHS_OFFSET != 0
+            // Row sum of A: calculate the sum of each row by multiplying A with
+            // a matrix of 1s from the right
+            a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
+#endif // RHS_OFFSET != 0
+        })
+
+#if LHS_OFFSET != 0
+        // Column sum of B: calculate the sum of each column by multiplying B
+        // with a matrix of 1s from the left
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            VEC_DATA_TYPE(DATA_TYPE, K0)
+            vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
+
+            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
+        })
+#endif // LHS_OFFSET != 0
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
+        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
+    }
+
+    // Do not write if the coordinates are out of bounds
+    // However, the reads above must still be performed, because arm_matrix_multiply()
+    // expects every thread in the MMUL block to make the same number of calls
+    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
+    {
+        return;
+    }
+
+#if RHS_OFFSET != 0 || LHS_OFFSET != 0
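+    // Apply the cross terms from the expansion noted at the accumulator
+    // initialization: c -= RHS_OFFSET * sum(a) + LHS_OFFSET * sum(b)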
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
+        LOOP_UNROLLING(int, j, 0, 1, N0,
+        {
+            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+        })
+    })
+#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
+
+#ifdef BIAS
+    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
+#endif // defined(BIAS)
+
+    // Quantize the tile
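+    // Requantization maps the int32 accumulators back to the 8-bit output scale,
+    // approximately dst = round(c * DST_MULTIPLIER * 2^(-DST_SHIFT)) + DST_OFFSET.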
+    TILE(DATA_TYPE, M0, N0, cq);
+    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
 }
 #endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_NT)
 
@@ -533,7 +678,9 @@
 /** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed
  *
  * Supported block configurations:
- *     TODO: Report supported M0, N0, K0
+ *  - M0 = 1, 2, 3, 4, 8, 16
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 4
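+ *    (K0 is fixed at 4 because each thread passes a 4-element vector to
+ *    arm_matrix_multiply() per call)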
  *
  * Similar to mat_mul_native_quantized_mmul_nt_nt()
  */
@@ -545,5 +692,141 @@
 #endif // defined(BIAS)
     TENSOR3D_T(dst, BUFFER))
 {
+    const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
+    // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
+    const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
+    const uint z  = get_global_id(2); // Batch
+
+    // Get section coordinates
+    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
+    const uint section_y = y0;
+
+    // Get thread coordinates within an mmul block
+    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+    const uint thread_x  = thread_id % MMUL_N0;
+    const uint thread_y  = (thread_id / MMUL_N0);
+
+    // Calculate dst coordinates
+    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
+    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
+    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
+    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
+
+    // Starting LHS coordinates
+    const uint lhs_x = dst_y;
+    const uint lhs_y = K0 * thread_x;
+
+    // Starting RHS coordinates
+    const uint rhs_x = K0 * thread_y;
+    const uint rhs_y = dst_x;
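+
+    // With RHS transposed, K runs along the x direction of the RHS tensor: the
+    // thread's K0 slice offsets x, while the destination column selects the row.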
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    TILE(int, M0, N0, c);
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
+    })
+
+    // Calculate row and column sums
+    TILE(int, 1, N0, b_sum);
+    b_sum[0].v = 0;
+
+    TILE(int, 1, M0, a_sum);
+    a_sum[0].v = 0;
+
+    VEC_DATA_TYPE(DATA_TYPE, K0)
+    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
+
+    for(int k = 0; k < lhs_h; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, K0, M0, a);
+        TILE(DATA_TYPE, N0, K0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
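+
+        // With RHS transposed, tile b is loaded as N0 rows x K0 columns, so each
+        // row b[n0].v is already the K0-element vector consumed below; no gather
+        // is needed, unlike in the T/NT kernel above.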
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            VEC_DATA_TYPE(DATA_TYPE, K0)
+            vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);
+
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                c[m0].s[n0] = arm_matrix_multiply(vec_a, b[n0].v, c[m0].s[n0]);
+            })
+#if RHS_OFFSET != 0
+            // Row sum of A: calculate the sum of each row by multiplying A with
+            // a matrix of 1s from the right
+            a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
+#endif // RHS_OFFSET != 0
+        })
+
+#if LHS_OFFSET != 0
+        // Column sum of B: calculate the sum of each column by multiplying B
+        // with a matrix of 1s from the left
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, b[n0].v, b_sum[0].s[n0]);
+        })
+#endif // LHS_OFFSET != 0
+
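+        // Advance one MMUL_K0 step: along y for the transposed LHS, but along x
+        // (in elements) for the transposed RHS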
+        lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
+        rhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+    }
+
+    // Do not write if the coordinates are out of bounds
+    // However, the reads above must still be performed, because arm_matrix_multiply()
+    // expects every thread in the MMUL block to make the same number of calls
+    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
+    {
+        return;
+    }
+
+#if RHS_OFFSET != 0 || LHS_OFFSET != 0
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
+        LOOP_UNROLLING(int, j, 0, 1, N0,
+        {
+            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+        })
+    })
+#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
+
+#ifdef BIAS
+    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
+#endif // defined(BIAS)
+
+    // Quantize the tile
+    TILE(DATA_TYPE, M0, N0, cq);
+    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
 }
 #endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_T)