Implement Quantized MatMul kernel using MMUL extension

Resolves: COMPMID-6475
Change-Id: Ic867cdfff5d4391cb749a04bf7cc35cda63d3b71
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10311
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl b/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
index 56e278c..9123e5b 100644
--- a/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
+++ b/src/core/CL/cl_kernels/common/mat_mul_quantized_mmul.cl
@@ -40,10 +40,33 @@
 }
 #endif // defined(BIAS)
 
+#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix
+
 #if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
 /** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
  *
- * TODO: report build configuration
+ * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
+ *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=uchar)
+ * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at
+ *       compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
+ * @note The number of leftover outputs rows/columns must be passed using -DN0_LEFTOVER and -DM0_LEFTOVER
+ *       (e.g. -DN0_LEFTOVER=2, -DM0_LEFTOVER=3)
+ * @note The dimensions M, N, K must be passed at compile time using -DK (e.g. -DM=5, -DN=8, -DK=6).
+ *       K must be a multiple of 16.
+ * @note MMUL block sizes must be passed at compile time using -DMMUL_K0, -DMMUL_M0, -DMMUL_N0
+ *       (e.g. -DMMUL_K0=16, -DMMUL_M0=4, -DMMUL_N0=4)
+ * @note If there is bias -DBIAS option must be passed at compile time
+ * @note Quantization offsets of lhs, rhs and dst tensors must be passed at compile time using -DLHS_OFFSET,
+ *       -DRHS_OFFSET, -DDST_OFFSET (e.g. -DLHS_OFFSET=10, -DRHS_OFFSET=0, -DDST_OFFSET=-6)
+ * @note Effective quantization multiplier and shift for the destination tensor must be passed at compile time using
+ *       -DDST_MULTIPLIER and -DDST_SHIFT (e.g. -DDST_MULTIPLIER=2091, -DST_SHIFT=8)
+ * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 > 0
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 4
+ * @note For a generic view on how the MMUL works, see mat_mul_mmul.cl
  *
  * @param[in]  lhs_ptr                            Pointer to the lhs matrix. Supported data types: QASYMM8_SIGNED/QASYMM8
  * @param[in]  lhs_stride_y                       Stride of the lhs matrix in Y (2nd) dimension (in bytes)
@@ -59,7 +82,7 @@
  * @param[in]  rhs_h                              The height of the rhs tensor
  * @param[in]  rhs_n                              Number of the matrices (buffers) in the batch
  * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the rhs matrix
- * @param[in]  bias_ptr                           (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  bias_ptr                           (Optional) Pointer to the bias tensor. Supported data type: S32
  * @param[in]  bias_stride_y                      (Optional) Stride of the bias tensor in Y dimension (in bytes)
  * @param[in]  bias_stride_z                      (Optional) Stride of the bias tensor in Z dimension (in bytes)
  * @param[in]  bias_w                             (Optional) The size of the width dimension of the bias tensor
@@ -82,6 +105,252 @@
 #endif // defined(BIAS)
     TENSOR3D_T(dst, BUFFER))
 {
+    // The explanation of how this kernel works is very similar to the explanation given in
+    // mat_mul_mmul.cl. The MMUL logic, and terminology is the same. The only difference is
+    // in quantization multiplication, the MMUL block sizes are (4 x 16) for Lhs matrix and
+    // (16 x 4) for Rhs matrix, resulting in (4 x 4) MMUL block size for the destination.
+    //
+    // Figures 1, 2 and 3 in the previous explanation works the same. Since the Lhs and Rhs
+    // MMUL block sizes are different in quantized extension, the thread access pattern is
+    // slightly different. We can redraw Figure 4 (Thread access pattern) as follows:
+    //
+    //                                 (Modified Figure 4 from mat_mul_mmul.cl)
+    //                               Thread Access Layouts in LHS & RHS matrices
+    //
+    //                                    LHS matrix
+    //                 4 times      4 times          4 times        4 times
+    //           _______________________________________________________________
+    //          |T0_|T0_|T0_|T0_|T1_|T1_|T1_|T1_|T2_|T2_|T2_|T2_|T3_|T3_|T3_|T3_|
+    //          |T0_| ...                                                       |
+    //    M0    |   .    .                                                      |
+    //   Times  |   .       .                                                   |
+    //          |   .           .                                               |
+    //          |T0_|T0_|T0_|T0_|T1_|T1_|T1_|T1_|T2_|T2_|T2_|T2_|T3_|T3_|T3_|T3_|
+    //          |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
+    //          |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
+    //    M0    |   .    .                                                      |
+    //   Times  |   .       .                                                   |
+    //          |   .           .                                               |
+    //          |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
+    //          |T8_|T8_|T8_|T8_|T9_|T9_|T9_|T9_|T10|T10|T10|T10|T11|T11|T11|T11|
+    //    M0    |   .                                                           |
+    //   Times  |   .                                                           |
+    //          |   .                                                           |
+    //          |T8_|T8_|T8_|T8_|T9_|T9_|T9_|T9_|T10|T10|T10|T10|T11|T11|T11|T11|
+    //    M0    |   .                                                           |
+    //   Times  |   .                                                           |
+    //          |   .                                                           |
+    //          |T12|T12|T12|T12|T13|T13|T13|T13|T14|T14|T14|T14|T15|T15|T15|T15|
+    //
+    //
+    //                                                  RHS Matrix
+    //
+    //                   __________N0 times______N0 times____________________N0 times_______
+    //                  |__T0__| ... |__T0__|__T1__| ...  |__T1__| ... |__T3__| ... |__T3__|
+    //       4 times    |__T0__| ... |__T0__|__T1__| ...  |__T1__| ... |__T3__| ... |__T3__|
+    //                  |__T0__| ... |__T0__|__T1__| ...  |__T1__| ... |__T3__| ... |__T3__|
+    //                  |__T0__| ... |__T0__|__T1__| ...  |__T1__| ... |__T3__| ... |__T3__|
+    //                  |__T4__| ... |__T4__|__T5__| ...  |__T5__| ... |__T7__| ... |__T7__|
+    //       4 times    |__T4__| ... |__T4__|__T5__| ...  |__T5__| ... |__T7__| ... |__T7__|
+    //                  |__T4__| ... |__T4__|__T5__| ...  |__T5__| ... |__T7__| ... |__T7__|
+    //           X      |__T4__| ... |__T4__|__T5__| ...  |__T5__| ... |__T7__| ... |__T7__|
+    //                  |__T8__| ... |__T8__|__T9__| ...  |__T9__| ... |__T11_| ... |__T11_|
+    //                  |__T8__| ... |__T8__|__T9__| ...  |__T9__| ... |__T11_| ... |__T11_|
+    //       4 times    |__T8__| ... |__T8__|__T9__| ...  |__T9__| ... |__T11_| ... |__T11_|
+    //                  |__T8__| ... |__T8__|__T9__| ...  |__T9__| ... |__T11_| ... |__T11_|
+    //                  |__T12_| ... |__T12_|__T13_| ...  |__T13_| ... |__T15_| ... |__T15_|
+    //       4 times    |__T12_| ... |__T12_|__T13_| ...  |__T13_| ... |__T15_| ... |__T15_|
+    //                  |__T12_| ... |__T12_|__T13_| ...  |__T13_| ... |__T15_| ... |__T15_|
+    //                  |__T12_|_____|__T12_|__T13_|______|__T13_|_____|__T15_|_____|__T15_|
+    //
+    //
+    // The logic behind this thread access pattern is already descried in the explanation
+    // in mat_mul_mmul.cl. The only change is threads accesses are extended to 4 elements
+    // from 1, in rightward direction in Lhs, and in downward direction in Rhs, because they
+    // are now operating on 4 char/uchar's (again 32-bit data), instead of one 32-bit floating point.
+    //
+    // The mathematical view of the matrix multiplication explained in Figure 5 also holds for this,
+    // except the dimension 4 is 16 instead, but the vector notations do not change, i.e. it's as follows:
+    //
+    //   Settings:
+    //          - a 8 x 16 LHS section
+    //          - 16 x 8 RHS section
+    //          - Each vector variable ai, bj represent a 16x1 vector
+    //          - ^T (superscript T) denotes transpose
+    //          - M0 = N0 = 2
+    //          - MMUL_N0 = MMUL_M0 = 4, MMUL_K0 = 16
+    //
+    //
+    //                                             (Modified Figure 5)
+    //                              Mathematical view of the Matrix Multiplication
+    //
+    //      LHS                           RHS                                           DST
+    //    [  a1^T  ]            [ b1 b2 b3 b4 b5 b6 b7 ]                [ a1^Tb1  a1^Tb2  a1^Tb3 ... a1^Tb7 ]
+    //    [  a2^T  ]                                    16 x 8          [ a2^Tb1  a2^Tb2  a2^Tb3 ... a2^Tb7 ]
+    //    [  a3^T  ]                                                    [                                   ]
+    //    [  a4^T  ]                                                =   [   .       .                       ]
+    //    [  a5^T  ]        X                                           [   .          .                    ]
+    //    [  a6^T  ]                                                    [   .             .                 ]
+    //    [  a7^T  ]                                                    [                                   ]
+    //    [  a8^T  ]                                                    [ a7^Tb1  a7^Tb2  a7^Tb3 ... a7^Tb7 ]
+    //              8 x 16                                                                                     8 x 8
+    //
+    //
+    //  For the first iteration, i.e. (m0, n0) = (0, 0), the arm_matrix_multiply would multiply the following matrices:
+    //
+    //    [  a1^T  ]            [  b1 b3 b5 b7 ]                [ a1^Tb1  a1^Tb3  a1^Tb5  a1^Tb7 ]
+    //    [  a3^T  ]        x                   4 x 4     =     [ a3^Tb1  a1^Tb3  a1^Tb5  a1^Tb7 ]
+    //    [  a5^T  ]                                            [ a5^Tb1  a1^Tb3  a1^Tb5  a1^Tb7 ]
+    //    [  a7^T  ]                                            [ a7^Tb1  a7^Tb3  a7^Tb5  a7^Tb7 ]
+    //              4 x 4                                                                         4 x 4
+    // The elements calculated in the 4x4 output block are the "interleaved" elements in the DST above.
+    // When we follow for each combination of (m0, n0), every element of the DST matrix "section" is filled.
+    //
+    // Please refer to mat_mul_mmul.cl for more details.
+
+    const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
+    // The upper limit is a simplified version of (N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE)
+    const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
+    const uint z  = get_global_id(2); // Batch
+
+    // Get section coordinates
+    const uint section_x = (x0 / MMUL_BLOCK_SIZE);
+    const uint section_y = y0;
+
+    // Get thread coordinates within an mmul block
+    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+    const uint thread_x  = thread_id % MMUL_N0;
+    const uint thread_y  = (thread_id / MMUL_N0);
+
+    // Calculate dst coordinates
+    const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
+    const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
+    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
+    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
+
+    // Starting LHS coordinates
+    const uint lhs_x = K0 * thread_x;
+    const uint lhs_y = dst_y;
+
+    // Starting RHS coordinates
+    const uint rhs_x = dst_x;
+    const uint rhs_y = K0 * thread_y;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    TILE(int, M0, N0, c);
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
+    })
+
+    // Calculate row and column sums
+    TILE(int, 1, N0, b_sum);
+    b_sum[0].v = 0;
+
+    TILE(int, 1, M0, a_sum);
+    a_sum[0].v = 0;
+
+    VEC_DATA_TYPE(DATA_TYPE, K0)
+    vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
+
+    for(int k = 0; k < lhs_w; k += MMUL_K0)
+    {
+        // A tile of M0xK0 but K0 must be set to K0
+        TILE(DATA_TYPE, M0, K0, a);
+        // A tile of K0xN0 but K0 must be set to K0
+        TILE(DATA_TYPE, K0, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                VEC_DATA_TYPE(DATA_TYPE, K0)
+                vec_b       = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
+                c[m0].s[n0] = arm_matrix_multiply(a[m0].v, vec_b, c[m0].s[n0]);
+            })
+        })
+
+#if RHS_OFFSET != 0
+        // Row Sum of A: Calculate the sum of rows by multiplying A with
+        // a matrix of 1's from Right
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            a_sum[0].s[m0] = arm_matrix_multiply(a[m0].v, vec_1, a_sum[0].s[m0]);
+        })
+#endif // RHS_OFFSET != 0
+
+#if LHS_OFFSET != 0
+        // Column Sum of B: Calculate the sum of columns by multiplying B
+        // with a matrix of 1's from Left
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            VEC_DATA_TYPE(DATA_TYPE, K0)
+            vec_b          = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
+            b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
+        })
+#endif // LHS_OFFSET != 0
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
+    }
+
+    // Do not write if the coordinates are out of bound
+    // But, read has to happen as arm_matrix_multiply() expects certain number of calls
+    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
+    {
+        return;
+    }
+
+#if RHS_OFFSET != 0 || LHS_OFFSET != 0
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
+        LOOP_UNROLLING(int, j, 0, 1, N0,
+        {
+            c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+        })
+    })
+#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
+
+#ifdef BIAS
+    perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
+#endif // defined(BIAS)
+
+    // Quantize the tile
+    TILE(DATA_TYPE, M0, N0, cq);
+    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
 }
 #endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
 
diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.cpp
index 4a6a3f3..464212d 100644
--- a/src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.cpp
@@ -29,15 +29,17 @@
 #include "arm_compute/core/ITensorPack.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/utils/StringUtils.h"
+#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 
 #include "src/common/utils/Log.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/helpers/WindowHelpers.h"
 #include "src/gpu/cl/ClCompileContext.h"
 #include "src/gpu/cl/kernels/helpers/MatMulKernelHelpers.h"
 #include "support/Cast.h"
 #include "support/StringSupport.h"
+#include "utils/TypePrinter.h"
 
 namespace arm_compute
 {
@@ -54,8 +56,25 @@
 
 Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
 {
-    ARM_COMPUTE_UNUSED(matmul_kernel_info);
-    // TODO: Validate MatMulKernelInfo
+    const bool adj_lhs = matmul_kernel_info.adj_lhs;
+    const int  m0      = matmul_kernel_info.m0;
+    const int  n0      = matmul_kernel_info.n0;
+    const int  k0      = matmul_kernel_info.k0;
+
+    // Validate M0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+
+    if(adj_lhs)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 != 1) && (m0 != 2) && (m0 != 3) && (m0 != 4) && (m0 != 8) && (m0 != 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed");
+    }
+
+    // Validate N0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 != 1) && (n0 != 2) && (n0 != 3) && (n0 != 4) && (n0 != 8) && (n0 != 16), "Only 1,2,3,4,8,16 are supported for N0");
+
+    // Validate K0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 != 4), "Only 4 is supported for k0");
+
     return Status{};
 }
 } // namespace
@@ -69,17 +88,21 @@
                                               const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()),
+                                    "The extension cl_arm_matrix_multiply is not supported on the target platform");
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
     ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
-    // TODO: Check MMUL block sizes against the tensor shapes
-    ARM_COMPUTE_UNUSED(mmul_k0);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_info.activation() != ActivationFunction::IDENTITY && act_info.activation() != ActivationFunction::RELU
-                                     && act_info.activation() != ActivationFunction::LU_BOUNDED_RELU && act_info.activation() != ActivationFunction::BOUNDED_RELU),
+    const TensorShape &lhs_shape = lhs->tensor_shape();
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_input_shapes(lhs_shape, rhs->tensor_shape(), matmul_kernel_info));
+
+    const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0);
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((act_info.activation() != ActivationFunction::IDENTITY),
                                     "Activation Function specified is unsupported.");
-    const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
+    const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs_shape, rhs->tensor_shape(), matmul_kernel_info);
 
     if(dst->total_size() != 0)
     {
@@ -99,8 +122,7 @@
 }
 
 void ClMatMulLowpNativeMMULKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *bias, ITensorInfo *dst,
-                                             const MatMulKernelInfo    &matmul_kernel_info,
-                                             const ActivationLayerInfo &act_info)
+                                             const MatMulKernelInfo &matmul_kernel_info, const ActivationLayerInfo &act_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
     ARM_COMPUTE_LOG_PARAMS(lhs, rhs, bias, dst, matmul_kernel_info, act_info);
@@ -110,14 +132,55 @@
     auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
 
     ARM_COMPUTE_UNUSED(compile_context, lhs, rhs, bias, matmul_kernel_info, act_info);
+    CLBuildOptions build_opts;
+
+    const int m = dst->dimension(1);
+    const int n = dst->dimension(0);
+    const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+
+    const int m0 = std::min(matmul_kernel_info.m0, m);
+    const int n0 = adjust_vec_size(matmul_kernel_info.n0, n);
+
+    // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks
+    // at the end of a row/column if any. This is to avoid padding.
+    const unsigned int m0_leftover = m % m0;
+    const unsigned int n0_leftover = n % n0;
 
     // Configure kernel window
     const auto win_config = validate_and_configure_window_for_mmul_kernels(lhs, rhs, dst, matmul_kernel_info, mmul_m0, mmul_n0);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     IClKernel::configure_internal(win_config.second);
 
-    CLBuildOptions build_opts;
-    // TODO: Build options & configuration
+    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(lhs->data_type()));
+    build_opts.add_option("-DM=" + support::cpp11::to_string(m));
+    build_opts.add_option("-DN=" + support::cpp11::to_string(n));
+    build_opts.add_option("-DK=" + support::cpp11::to_string(k));
+    build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
+    build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
+    build_opts.add_option("-DK0=" + support::cpp11::to_string(matmul_kernel_info.k0));
+    build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover));
+    build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover));
+    build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0));
+    build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0));
+    build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0));
+    build_opts.add_option_if(bias != nullptr, "-DBIAS");
+
+    const UniformQuantizationInfo lqinfo = lhs->quantization_info().uniform();
+    const UniformQuantizationInfo rqinfo = rhs->quantization_info().uniform();
+    const UniformQuantizationInfo dqinfo = dst->quantization_info().uniform();
+
+    float multiplier        = lqinfo.scale * rqinfo.scale / dqinfo.scale;
+    int   output_multiplier = 0;
+    int   output_shift      = 0;
+    arm_compute::quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+
+    build_opts.add_option("-DDST_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
+    build_opts.add_option("-DDST_SHIFT=" + support::cpp11::to_string(output_shift));
+
+    // Note : Offset is not negated, unlike gemmlowp kernels
+    build_opts.add_option("-DLHS_OFFSET=" + support::cpp11::to_string(lqinfo.offset));
+    build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(rqinfo.offset));
+    build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset));
 
     std::string kernel_name("mat_mul_native_quantized_mmul");
     kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt";
@@ -129,7 +192,22 @@
     // Create kernel
     _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
 
-    // TODO: Tuner configuration
+    // Set config_id for enabling LWS tuning
+    _config_id = kernel_name;
+    _config_id += "_";
+    _config_id += lower_string(string_from_data_type(lhs->data_type()));
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(m);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(n);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(k);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(dst->dimension(2));
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(m0);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(n0);
 }
 
 void ClMatMulLowpNativeMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
index 2420ad6..432270e 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -34,7 +34,6 @@
 
 #include "src/common/utils/Log.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/helpers/WindowHelpers.h"
 #include "src/gpu/cl/kernels/helpers/MatMulKernelHelpers.h"
 
 #include "support/Cast.h"
@@ -136,9 +135,7 @@
     const int n0 = adjust_vec_size(matmul_kernel_info.n0, n);
 
     // Configure kernel window
-    const auto win_config = validate_and_configure_window_for_mmul_kernels(lhs, rhs, dst, matmul_kernel_info, mmul_m0,
-                                                                           mmul_n0);
-
+    const auto win_config = validate_and_configure_window_for_mmul_kernels(lhs, rhs, dst, matmul_kernel_info, mmul_m0, mmul_n0);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     IClKernel::configure_internal(win_config.second);
 
diff --git a/tests/datasets/MatMulLowpMMULDataset.h b/tests/datasets/MatMulLowpMMULDataset.h
new file mode 100644
index 0000000..1b22e10
--- /dev/null
+++ b/tests/datasets/MatMulLowpMMULDataset.h
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_TESTS_DATASETS_MATMULLOWPMMULDATASET_H
+#define ACL_TESTS_DATASETS_MATMULLOWPMMULDATASET_H
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "tests/datasets/MatMulDataset.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace datasets
+{
+/** MatMulLowp MMUL shapes are similar to MatMul MMUL shapes except that K has to be a
+ * multiple of MMUL_K0 which is 16 (e.g. see src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.cpp for the definition)
+ */
+class SmallMatMulLowpMMULDataset final : public MatMulDataset
+{
+public:
+    SmallMatMulLowpMMULDataset()
+    {
+        add_config(TensorShape(16U, 4U), TensorShape(4U, 16U), TensorShape(4U, 4U)); // same as mmul block
+        add_config(TensorShape(96U, 1U), TensorShape(1U, 96U), TensorShape(1U, 1U)); // vector x vector
+        add_config(TensorShape(32U, 4U, 2U), TensorShape(16U, 32U, 2U), TensorShape(16U, 4U, 2U));
+        add_config(TensorShape(48U, 2U), TensorShape(17U, 48U), TensorShape(17U, 2U));
+        add_config(TensorShape(32U, 6U), TensorShape(7U, 32U), TensorShape(7U, 6U));
+    }
+};
+
+// This dataset is for smaller number of tests that will still use small shapes
+// e.g. not repeating everything for QASYMM8 while we're already testing for QASYMM8_SIGNED
+class SmallMatMulLowpMMULDatasetSubset final : public MatMulDataset
+{
+public:
+    SmallMatMulLowpMMULDatasetSubset()
+    {
+        add_config(TensorShape(32U, 4U, 2U), TensorShape(16U, 32U, 2U), TensorShape(16U, 4U, 2U));
+        add_config(TensorShape(32U, 6U), TensorShape(7U, 32U), TensorShape(7U, 6U));
+    }
+};
+
+class SmallMatMulLowpMMULWithBiasDataset final : public MatMulDataset
+{
+public:
+    SmallMatMulLowpMMULWithBiasDataset()
+    {
+        add_config(TensorShape(32U, 4U, 2U, 2U), TensorShape(16U, 32U, 2U, 2U), TensorShape(16U, 4U, 2U, 2U));
+    }
+};
+
+class LargeMatMulLowpMMULDataset final : public MatMulDataset
+{
+public:
+    LargeMatMulLowpMMULDataset()
+    {
+        add_config(TensorShape(192U, 38U, 3U, 2U), TensorShape(21U, 192U, 3U, 2U), TensorShape(21U, 38U, 3U, 2U));
+    }
+};
+
+class HighDimensionalMatMulLowpMMULDataset final : public MatMulDataset
+{
+public:
+    HighDimensionalMatMulLowpMMULDataset()
+    {
+        add_config(TensorShape(16U, 5U, 2U, 2U, 2U, 2U), TensorShape(5U, 16U, 2U, 2U, 2U, 2U), TensorShape(5U, 5U, 2U, 2U, 2U, 2U)); // 6D tensor
+    }
+};
+
+} // namespace datasets
+} // namespace test
+} // namespace arm_compute
+
+#endif // ACL_TESTS_DATASETS_MATMULLOWPMMULDATASET_H
diff --git a/tests/validation/CL/MatMulLowpNativeMMULKernel.cpp b/tests/validation/CL/MatMulLowpNativeMMULKernel.cpp
index 10d893e..a361a5a 100644
--- a/tests/validation/CL/MatMulLowpNativeMMULKernel.cpp
+++ b/tests/validation/CL/MatMulLowpNativeMMULKernel.cpp
@@ -26,8 +26,7 @@
 
 #include "src/gpu/cl/kernels/ClMatMulLowpNativeMMULKernel.h"
 
-#include "tests/datasets/LargeMatMulDataset.h"
-#include "tests/datasets/SmallMatMulDataset.h"
+#include "tests/datasets/MatMulLowpMMULDataset.h"
 #include "tests/framework/Macros.h"
 #include "tests/framework/datasets/Datasets.h"
 #include "tests/validation/Validation.h"
@@ -44,14 +43,27 @@
 {
 namespace
 {
-// TODO: enable
-// constexpr AbsoluteTolerance<float> tolerance_quant(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
+constexpr AbsoluteTolerance<float> tolerance_quant(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
 }
-template <typename T>
-using CLMatMulLowpNativeMMULKernelFixture = MatMulKernelValidationFixture<T, ClMatMulLowpNativeMMULKernel>;
+using framework::dataset::make;
 
 template <typename T>
-using CLMatMulLowpKernelWithBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulLowpNativeMMULKernel>;
+using CLMatMulLowpNativeMMULKernelFixture = MatMulKernelValidationFixture<T, ClMatMulLowpNativeMMULKernel, true /* use_mmul */>;
+
+template <typename T>
+using CLMatMulLowpNativeMMULKernelWithBiasFixture = MatMulKernelWithBiasValidation<T, ClMatMulLowpNativeMMULKernel, true /* use_mmul */>;
+
+/** M0 values to test --precommit*/
+const auto m0_values_precommit = framework::dataset::make("M0", { 1, 3 });
+
+/** N0 values to test --precommit*/
+const auto n0_values_precommit = framework::dataset::make("N0", { 2, 4 });
+
+/** M0 values to test --nightly*/
+const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", { 2, 4, 5, 8 });
+
+/** N0 values to test --nightly*/
+const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", { 1, 3, 8, 16 });
 
 TEST_SUITE(CL)
 TEST_SUITE(MatMulLowpNativeMMULKernel)
@@ -77,9 +89,10 @@
     for(auto &pair : supported_block_sizes)
     {
         TensorInfo output_info;
-        Status     status = ClMatMulLowpNativeMMULKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
+        Status     status   = ClMatMulLowpNativeMMULKernel::validate(&lhs_info, &rhs_info, nullptr, &output_info, pair.first);
+        const bool expected = (pair.second && arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()));
 
-        ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
     }
 }
 
@@ -89,21 +102,22 @@
     using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, TensorShape, bool>;
     const std::vector<ShapeConfigurationTuple> shape_configurations =
     {
-        { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U), true },
-        { TensorShape(10U, 12U), TensorShape(3U, 10U), TensorShape(3U), true },
-        { TensorShape(8U, 4U), TensorShape(2U, 8U), TensorShape(2U), true },
-        { TensorShape(8U, 4U), TensorShape(2U, 5U), TensorShape(2U), false }, // Mismatch in the K dimension
-        { TensorShape(5U, 0U), TensorShape(2U, 5U), TensorShape(2U), false }, // Invalid dimension
-        { TensorShape(5U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
-        { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // no batch broadcasting
-        { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // mismatch in batch dimension
-        { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(1U), false },                                 // invalid broadcast of bias
-        { TensorShape(5U, 1U), TensorShape(3U, 5U), TensorShape(3U, 3U), false },                             // 2d bias is invalid
+        { TensorShape(32U, 1U), TensorShape(3U, 32U), TensorShape(3U), true },
+        { TensorShape(16U, 12U), TensorShape(3U, 16U), TensorShape(3U), true },
+        { TensorShape(64U, 4U), TensorShape(2U, 64U), TensorShape(2U), true },
+        { TensorShape(16U, 4U), TensorShape(2U, 32U), TensorShape(2U), false }, // Mismatch in the K dimension
+        { TensorShape(16U, 0U), TensorShape(2U, 16U), TensorShape(2U), false }, // Invalid dimension
+        { TensorShape(32U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 32U, 3U, 4U, 5U, 6U), TensorShape(2U), true },
+        { TensorShape(32U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 32U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // no batch broadcasting
+        { TensorShape(32U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 32U, 3U, 4U, 5U, 6U), TensorShape(2U), false }, // mismatch in batch dimension
+        { TensorShape(32U, 1U), TensorShape(3U, 32U), TensorShape(1U), false },                                 // invalid broadcast of bias
+        { TensorShape(32U, 1U), TensorShape(3U, 32U), TensorShape(3U, 3U), false },                             // 2d bias is invalid
+        { TensorShape(12U, 12U), TensorShape(3U, 12U), TensorShape(3U), false },                                // K must be multiple of 16
     };
 
     for(auto &tuple : shape_configurations)
     {
-        const bool expected = std::get<3>(tuple);
+        const bool expected = (std::get<3>(tuple) && arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()));
 
         for(bool adj_lhs :
             {
@@ -134,7 +148,7 @@
                 const TensorInfo bia_info = TensorInfo(bia_shape, 1, DataType::S32);
                 TensorInfo       output_info;
 
-                MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
+                MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 4, false /* export_rhs_to_cl_image */ };
 
                 Status status = ClMatMulLowpNativeMMULKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
                 ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
@@ -172,10 +186,10 @@
     // It's enough to test a single shape and block size configuration while checking data types
     const TensorShape      shape     = TensorShape(48U, 48U);
     const TensorShape      bia_shape = TensorShape(48U);
-    const MatMulKernelInfo matmul_kernel_info{ false, false, 1, 1, 1, false };
+    const MatMulKernelInfo matmul_kernel_info{ false, false, 1, 1, 4, false };
     for(auto &tuple : data_type_configurations)
     {
-        const bool expected = std::get<4>(tuple);
+        const bool expected = (std::get<4>(tuple) && arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()));
 
         const TensorInfo lhs_info(shape, 1, std::get<0>(tuple));
         const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
@@ -183,6 +197,7 @@
         TensorInfo       output_info(shape, 1, std::get<3>(tuple));
 
         Status status = ClMatMulLowpNativeMMULKernel::validate(&lhs_info, &rhs_info, &bia_info, &output_info, matmul_kernel_info);
+
         ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
     }
 }
@@ -192,12 +207,137 @@
 TEST_SUITE(Quantized)
 TEST_SUITE(QASYMM8_SIGNED)
 
-// TODO: tests
+FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulLowpNativeMMULKernelFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallMatMulLowpMMULDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_precommit,
+                               n0_values_precommit,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8_SIGNED)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
+
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulLowpNativeMMULKernelWithBiasFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallMatMulLowpMMULWithBiasDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_precommit,
+                               n0_values_precommit,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8_SIGNED)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulLowpNativeMMULKernelFixture<int8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(datasets::LargeMatMulLowpMMULDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_nightly_lhs_nt,
+                               n0_values_nightly_rhs_nt,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8_SIGNED)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
+
+// Running High Dimensional test is enough for qasymm8_signed, because we're stressing the number of dimensions, not data type or M0/N0/K0
+// It's a good idea to test for each Lhs/Rhs T/NT combinations because they're different CL kernels
+FIXTURE_DATA_TEST_CASE(RunHighDimensional, CLMatMulLowpNativeMMULKernelFixture<int8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::HighDimensionalMatMulLowpMMULDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               make("M0", { 2 }),
+                               make("N0", { 2 }),
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8_SIGNED)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
 
 TEST_SUITE_END() // QASYMM8_SIGNED
+
 TEST_SUITE(QASYMM8)
 
-// TODO: tests
+FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulLowpNativeMMULKernelFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallMatMulLowpMMULDatasetSubset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_precommit,
+                               n0_values_precommit,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
+
+FIXTURE_DATA_TEST_CASE(RunWithBias, CLMatMulLowpNativeMMULKernelWithBiasFixture<uint8_t>,
+                       framework::DatasetMode::ALL,
+                       combine(datasets::SmallMatMulLowpMMULWithBiasDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_precommit,
+                               n0_values_precommit,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
+
+FIXTURE_DATA_TEST_CASE(RunLargeNoTranspose, CLMatMulLowpNativeMMULKernelFixture<uint8_t>,
+                       framework::DatasetMode::NIGHTLY,
+                       combine(datasets::LargeMatMulLowpMMULDataset(),
+                               make("TransposeA", { false }),
+                               make("TransposeB", { false }),
+                               m0_values_nightly_lhs_nt,
+                               n0_values_nightly_rhs_nt,
+                               make("K0", { 4 }),
+                               make("ExportRhsToCLImage", { false }),
+                               make("DataType", DataType::QASYMM8)))
+{
+    if(_device_supports_mmul)
+    {
+        // Validate output
+        validate(CLAccessor(_target), _reference, tolerance_quant);
+    }
+}
 
 TEST_SUITE_END() // QASYMM8
 TEST_SUITE_END() // Quantized