Implement FP32/FP16 MatMul NT/NT kernel using the MMUL extension

Resolves COMPMID-6194

Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Ie45e2aa9533948b2e5235563cef1d3834494eccf
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9739
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
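
A minimal usage sketch (illustrative only, not part of the patch) of the new kernel's static
validation entry point, assuming the include paths and the MatMulKernelInfo field order
(adj_lhs, adj_rhs, M0, N0, K0, export_rhs_to_cl_image) used by the tests added below:

    #include "arm_compute/core/KernelDescriptors.h" // assumed location of MatMulKernelInfo
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h"

    using namespace arm_compute;

    bool mmul_matmul_config_is_valid()
    {
        // Lhs: [K, M] = [64, 32], Rhs: [N, K] = [16, 64]; K must be a multiple of MMUL_K0 (= 4)
        const TensorInfo lhs(TensorShape(64U, 32U), 1, DataType::F32);
        const TensorInfo rhs(TensorShape(16U, 64U), 1, DataType::F32);
        TensorInfo       dst; // left empty; configure() would auto-initialise it

        // NT/NT, M0 = 4, N0 = 4, K0 must be 1, no export of Rhs to cl_image
        const MatMulKernelInfo kernel_info(false, false, 4, 4, 1, false);

        const Status status = opencl::kernels::ClMatMulNativeMMULKernel::validate(&lhs, &rhs, &dst, kernel_info);
        return bool(status);
    }
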
diff --git a/Android.bp b/Android.bp
index b634a06..cfddf6e 100644
--- a/Android.bp
+++ b/Android.bp
@@ -51,6 +51,7 @@
         "src/core/CL/cl_kernels/common/instance_normalization.cl",
         "src/core/CL/cl_kernels/common/l2_normalize.cl",
         "src/core/CL/cl_kernels/common/mat_mul.cl",
+        "src/core/CL/cl_kernels/common/mat_mul_mmul.cl",
         "src/core/CL/cl_kernels/common/mat_mul_quantized.cl",
         "src/core/CL/cl_kernels/common/mean_stddev_normalization.cl",
         "src/core/CL/cl_kernels/common/memset.cl",
@@ -698,6 +699,7 @@
         "src/gpu/cl/kernels/ClIndirectConv2dKernel.cpp",
         "src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp",
         "src/gpu/cl/kernels/ClMatMulNativeKernel.cpp",
+        "src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp",
         "src/gpu/cl/kernels/ClMulKernel.cpp",
         "src/gpu/cl/kernels/ClPermuteKernel.cpp",
         "src/gpu/cl/kernels/ClPool2dKernel.cpp",
diff --git a/SConscript b/SConscript
index 904d5ba..320cb2d 100644
--- a/SConscript
+++ b/SConscript
@@ -395,6 +395,7 @@
                        'src/core/CL/cl_kernels/common/instance_normalization.cl',
                        'src/core/CL/cl_kernels/common/l2_normalize.cl',
                        'src/core/CL/cl_kernels/common/mat_mul.cl',
+                       'src/core/CL/cl_kernels/common/mat_mul_mmul.cl',
                        'src/core/CL/cl_kernels/common/mat_mul_quantized.cl',
                        'src/core/CL/cl_kernels/common/mean_stddev_normalization.cl',
                        'src/core/CL/cl_kernels/common/memset.cl',
diff --git a/filelist.json b/filelist.json
index 6c5b78f..f354e69 100644
--- a/filelist.json
+++ b/filelist.json
@@ -515,6 +515,7 @@
         "common": [
           "src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp",
           "src/gpu/cl/kernels/ClMatMulNativeKernel.cpp",
+          "src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp",
           "src/gpu/cl/operators/ClMatMul.cpp",
           "src/runtime/CL/functions/CLMatMul.cpp",
           "src/runtime/heuristics/matmul_native/ClMatMulNativeDefaultConfigValhall.cpp",
diff --git a/src/core/CL/cl_kernels/common/mat_mul_mmul.cl b/src/core/CL/cl_kernels/common/mat_mul_mmul.cl
new file mode 100644
index 0000000..1d94767
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/mat_mul_mmul.cl
@@ -0,0 +1,191 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "helpers.h"
+#include "tile_helpers.h"
+
+#if defined(MAT_MUL_NATIVE_MMUL_NT_NT)
+/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul) using MMUL: LHS non-transposed, RHS non-transposed - buffer only
+ *
+ * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
+ *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension.
+ * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
+ * @note The tile dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=1).
+ * @note The number of leftover output rows/columns must be passed using -DM0_LEFTOVER and -DN0_LEFTOVER (e.g. -DM0_LEFTOVER=3, -DN0_LEFTOVER=2)
+ * @note The MMUL block dimensions (MMUL_M0, MMUL_N0, MMUL_K0) must be passed at compile time using -DMMUL_M0, -DMMUL_N0 and -DMMUL_K0 (e.g. -DMMUL_M0=4, -DMMUL_N0=4, -DMMUL_K0=4).
+ * @note The dimension K must be passed at compile time using -DK (e.g. -DK=4). K must be a multiple of MMUL_K0
+ * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_MMUL_NT_NT)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 > 0
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 1
+ * @note Values > 8 for M0 are not expected to be efficient
+ *
+ * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
+ * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
+ * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
+ * @param[in]  lhs_w                             The width of the lhs tensor
+ * @param[in]  lhs_h                             The height of the lhs tensor
+ * @param[in]  lhs_n                             Number of the matrices (buffers) in the batch
+ * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
+ * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
+ * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
+ * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
+ * @param[in]  rhs_w                             The width of the rhs tensor
+ * @param[in]  rhs_h                             The height of the rhs tensor
+ * @param[in]  rhs_n                             Number of the matrices (buffers) in the batch
+ * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
+ * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
+ * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
+ * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
+ * @param[in]  dst_w                             The width of the dst tensor
+ * @param[in]  dst_h                             The height of the dst tensor
+ * @param[in]  dst_n                             Number of the matrices (buffers) in the batch
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
+ * @param[in]  M                                 Number of rows in LHS matrix
+ * @param[in]  N                                 Number of columns in RHS matrix
+ */
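+// Illustrative example of the compile-time options this kernel is built with (values are hypothetical; in practice
+// they are generated by ClMatMulNativeMMULKernel::configure()), e.g. for an F32 matmul with K = 64 and M, N multiples of 4:
+//   -DDATA_TYPE=float -DM0=4 -DN0=4 -DK0=1 -DM0_LEFTOVER=0 -DN0_LEFTOVER=0
+//   -DMMUL_M0=4 -DMMUL_N0=4 -DMMUL_K0=4 -DK=64 -DMAT_MUL_NATIVE_MMUL_NT_NT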
+__kernel void mat_mul_native_mmul_nt_nt(
+    TENSOR3D_T(lhs, BUFFER),
+    TENSOR3D_T(rhs, BUFFER),
+    TENSOR3D_T(dst, BUFFER),
+    const int M,
+    const int N)
+{
+#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0)
+
+    const uint x0 = get_global_id(0); // (N / N0) * MMUL_M0
+    const uint y0 = get_global_id(1); // (M / M0) / MMUL_M0
+    const uint z  = get_global_id(2); // Batch
+
+    // Get block coordinates
+    const uint block_x = (x0 / MMUL_BLOCK_SIZE);
+    const uint block_y = y0;
+
+    // Get thread coordinates within a block
+    const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+    const uint thread_x  = thread_id % MMUL_N0;
+    const uint thread_y  = (thread_id / MMUL_N0);
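+    // Illustrative example (MMUL_M0 = MMUL_N0 = 4, i.e. MMUL_BLOCK_SIZE = 16): the work-item with x0 = 37 belongs to
+    // block_x = 37 / 16 = 2 and has thread_id = 37 % 16 = 5, i.e. thread_x = 1 and thread_y = 1 within its MMUL block.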
+
+    // Starting destination coordinates
+    // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! Only after the matrix multiplication
+    // part can we exit the kernel if it is out-of-bound. Remember, we have a cooperative matrix multiplication. Therefore, we need a full block to get the correct results
+    // Although we will never write out-of-bound, we still need this clamp to ensure that we do not read out-of-bound either.
+    const uint dst_x_unclamped = thread_x * N0 + block_x * N0 * MMUL_N0;
+    const uint dst_y_unclamped = thread_y * M0 + block_y * M0 * MMUL_M0;
+    const uint dst_x           = min(dst_x_unclamped, (uint)(N - N0));
+    const uint dst_y           = min(dst_y_unclamped, (uint)(M - M0));
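+    // Illustrative example: with M = 10 and M0 = 4, a work-item whose dst_y_unclamped is 12 still takes part in the
+    // cooperative multiplication with dst_y clamped to min(12, 10 - 4) = 6, so its loads stay in bounds; it only
+    // returns after the loop below because dst_y_unclamped >= M.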
+
+    // Starting LHS coordinates
+    const uint lhs_x = thread_x;
+    const uint lhs_y = dst_y;
+
+    // Starting RHS coordinates
+    const uint rhs_x = dst_x;
+    const uint rhs_y = thread_y;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    // The MMUL extension accumulates the result in F32 for both F32 and F16 inputs
+    TILE(float, M0, N0, c_f32);
+
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_f32[i].v = 0;
+    })
+
+    for(int k = 0; k < K; k += MMUL_K0)
+    {
+        // An M0 x K0 tile of the LHS (K0 is fixed to 1)
+        TILE(DATA_TYPE, M0, 1, a);
+        // A K0 x N0 tile of the RHS (K0 is fixed to 1)
+        TILE(DATA_TYPE, 1, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]);
+            })
+        })
+
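+        // For example, with MMUL_K0 = 4 and DATA_TYPE = float the LHS pointer below moves 16 bytes (4 elements) to the
+        // right and the RHS pointer moves 4 rows down each iteration, i.e. the block advances to the next MMUL_K0-wide slice of K.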
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
+    }
+
+    // Work-items "outside" of the dst bounds do not write, but they still have to take part in the cooperative reads of arm_matrix_multiply.
+    // That is why this early exit can only happen after the matrix multiplication loop above.
+    if(dst_x_unclamped >= N || dst_y_unclamped >= M)
+    {
+        return;
+    }
+
+#if defined(HALF_PRECISION)
+    TILE(DATA_TYPE, M0, N0, c);
+
+    // Conversion required for the half precision
+    LOOP_UNROLLING(int, m0, 0, 1, M0,
+    {
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            c[m0].s[n0] = c_f32[m0].s[n0];
+        })
+    })
+#else // defined(HALF_PRECISION)
+#define c c_f32
+#endif // defined(HALF_PRECISION)
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#undef MMUL_BLOCK_SIZE
+}
+#endif // defined(MAT_MUL_NATIVE_MMUL_NT_NT)
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index a908004..408f1f7 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -319,6 +319,7 @@
     { "l2_normalize_x", "common/l2_normalize.cl" },
     { "l2_normalize_y", "common/l2_normalize.cl" },
     { "l2_normalize_z", "common/l2_normalize.cl" },
+    { "mat_mul_native_mmul_nt_nt", "common/mat_mul_mmul.cl" },
     { "mat_mul_native_nt_nt", "common/mat_mul.cl" },
     { "mat_mul_native_nt_t", "common/mat_mul.cl" },
     { "mat_mul_native_t_nt", "common/mat_mul.cl" },
@@ -799,6 +800,10 @@
 #include "./cl_kernels/common/mat_mul.clembed"
     },
     {
+        "common/mat_mul_mmul.cl",
+#include "./cl_kernels/common/mat_mul_mmul.clembed"
+    },
+    {
         "common/mat_mul_quantized.cl",
 #include "./cl_kernels/common/mat_mul_quantized.clembed"
     },
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
new file mode 100644
index 0000000..32e69ca
--- /dev/null
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/ITensorPack.h"
+#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
+#include "src/common/utils/Log.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/WindowHelpers.h"
+
+#include "support/Cast.h"
+#include "support/StringSupport.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+namespace kernels
+{
+namespace
+{
+// Block size dimensions for the MMUL extension
+constexpr int mmul_m0 = 4;
+constexpr int mmul_n0 = 4;
+constexpr int mmul_k0 = 4;
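+// One arm_matrix_multiply call is executed cooperatively by mmul_m0 * mmul_n0 = 16 work-items (see the window setup below)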
+
+inline std::pair<int, int> adjust_m0_n0(int m0, int n0, int m, int n)
+{
+    m0 = std::min(m0, m);
+    n0 = adjust_vec_size(n0, n);
+    return { m0, n0 };
+}
+
+Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
+{
+    const bool adj_lhs = matmul_kernel_info.adj_lhs;
+    const bool adj_rhs = matmul_kernel_info.adj_rhs;
+    const int  m0      = matmul_kernel_info.m0;
+    const int  n0      = matmul_kernel_info.n0;
+    const int  k0      = matmul_kernel_info.k0;
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs || adj_rhs), "adj_lhs and adj_rhs are not supported yet");
+
+    // Validate M0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+
+    // Validate N0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
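+    // e.g. n0 = 6 is rejected above because (6 & 5) != 0 and 6 != 3, while n0 = 8 passes because (8 & 7) == 0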
+
+    // Validate K0
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 != 1), "Only 1 is supported for k0");
+
+    return Status{};
+}
+
+Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
+{
+    ARM_COMPUTE_UNUSED(matmul_kernel_info);
+    const size_t lhs_k = lhs_shape.x();
+    const size_t rhs_k = rhs_shape.y();
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty");
+
+    constexpr size_t batch_dim_start = 2;
+    for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported");
+    }
+
+    return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info)
+{
+    ARM_COMPUTE_UNUSED(lhs, rhs);
+
+    const Window win = calculate_max_window(*dst, Steps(1, 1));
+
+    // Collapse along the Z direction
+    // This collapse needs to be here in order to tune the Z dimension of LWS
+    Window collapsed = win.collapse(win, Window::DimZ);
+
+    // Reconfigure the window size: one arm_matrix_multiply call needs 16 work-items (mmul_m0 * mmul_n0) to complete.
+    Window::Dimension x_dimension = collapsed.x();
+    Window::Dimension y_dimension = collapsed.y();
+
+    const int m = dst->dimension(1);
+    const int n = dst->dimension(0);
+
+    int m0{};
+    int n0{};
+    std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n);
+
+    // Make M and N multiples of M0 and N0 respectively
+    const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(n, n0);
+    const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(m, m0);
+
+    // Divide M and N by M0 and N0 respectively
+    const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / n0;
+    const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / m0;
+
+    // Make n_div_n0 and m_div_m0 multiples of mmul_n0 and mmul_m0 respectively
+    const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0);
+    const unsigned int ceil_to_multiple_m_div_m0_mmul_m0 = ceil_to_multiple(m_div_m0, mmul_m0);
+
+    // Ensure x_dimension is multiple of MMUL block size (mmul_m0 * mmul_n0)
+    x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_m0);
+    y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_m0 / mmul_m0);
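+    // Illustrative example: for n = 100, n0 = 4 -> n_div_n0 = 25 -> rounded up to 28, so the X end is 28 * 4 = 112
+    // work-items; for m = 64, m0 = 4 -> m_div_m0 = 16 (already a multiple of mmul_m0), so the Y end is 16 / 4 = 4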
+
+    collapsed.set(Window::DimX, x_dimension);
+    collapsed.set(Window::DimY, y_dimension);
+
+    return std::make_pair(Status{}, collapsed);
+}
+} // namespace
+ClMatMulNativeMMULKernel::ClMatMulNativeMMULKernel()
+{
+    _type = CLKernelType::GEMM;
+}
+
+Status ClMatMulNativeMMULKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform");
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+
+    if(dst->total_size() != 0)
+    {
+        const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
+    }
+
+    return Status{};
+}
+void ClMatMulNativeMMULKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
+    ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_kernel_info);
+    ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, dst, matmul_kernel_info));
+
+    // dst tensor auto initialization if not yet initialized
+    auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
+
+    const int m = dst->dimension(1);
+    const int n = dst->dimension(0);
+    const int k = lhs->tensor_shape().x();
+    _m          = m;
+    _n          = n;
+
+    int m0{};
+    int n0{};
+    std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n);
+
+    // Configure kernel window
+    const auto win_config = validate_and_configure_window(lhs, rhs, dst, matmul_kernel_info);
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+    IClKernel::configure_internal(win_config.second);
+
+    // Calculate partial (store instead of load) M0 and partial N0 for the partial blocks at the end of a row/column if any. This is to avoid padding.
+    const unsigned int m0_leftover = m % m0;
+    const unsigned int n0_leftover = n % n0;
+
+    CLBuildOptions build_opts;
+    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(lhs->data_type()));
+    build_opts.add_option_if(lhs->data_type() == DataType::F16, "-DHALF_PRECISION");
+    build_opts.add_option("-DM0=" + support::cpp11::to_string(m0));
+    build_opts.add_option("-DN0=" + support::cpp11::to_string(n0));
+    build_opts.add_option("-DK0=" + support::cpp11::to_string(matmul_kernel_info.k0));
+    build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover));
+    build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover));
+    build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0));
+    build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0));
+    build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0));
+    build_opts.add_option("-DK=" + support::cpp11::to_string(k));
+
+    std::string kernel_name("mat_mul_native_mmul_nt_nt");
+
+    // A macro guard to compile ONLY the kernel of interest
+    build_opts.add_option("-D" + upper_string(kernel_name));
+
+    // Create kernel
+    _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+
+    // Set config_id for enabling LWS tuning
+    _config_id = kernel_name;
+    _config_id += "_";
+    _config_id += lower_string(string_from_data_type(lhs->data_type()));
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(k);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(dst->dimension(2));
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(m0);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(n0);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(matmul_kernel_info.k0);
+}
+
+void ClMatMulNativeMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
+{
+    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
+
+    const ICLTensor *lhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
+    const ICLTensor *rhs = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
+    ICLTensor       *dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
+    ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst);
+    unsigned int idx = 0;
+
+    add_3d_tensor_nhw_argument(idx, lhs);
+    add_3d_tensor_nhw_argument(idx, rhs);
+    add_3d_tensor_nhw_argument(idx, dst);
+
+    // Pass M and N at runtime as signed ints to ensure that the result of any subtraction they are an operand in remains signed.
+    _kernel.setArg<cl_int>(idx++, _m);
+    _kernel.setArg<cl_int>(idx++, _n);
+
+    // LWS_x should be at least a multiple of 16 (the MMUL block size). (32, 2) has been chosen to place more work-items on a single core.
+    // The LWS also enforces the order of execution of the work-items, which improves cache utilization.
+    enqueue(queue, *this, window, cl::NDRange(32, 2), false);
+}
+
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h
new file mode 100644
index 0000000..26fe08c
--- /dev/null
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h
@@ -0,0 +1,93 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ACL_SRC_GPU_CL_KERNELS_CLMATMULNATIVEMMULKERNEL
+#define ACL_SRC_GPU_CL_KERNELS_CLMATMULNATIVEMMULKERNEL
+
+#include "src/core/common/Macros.h"
+#include "src/gpu/cl/ClCompileContext.h"
+#include "src/gpu/cl/IClKernel.h"
+
+namespace arm_compute
+{
+struct MatMulKernelInfo;
+namespace opencl
+{
+namespace kernels
+{
+class ClMatMulNativeMMULKernel : public IClKernel
+{
+public:
+    ClMatMulNativeMMULKernel();
+    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClMatMulNativeMMULKernel);
+    /** Initialize the kernel's input and output.
+     *
+     * This kernel performs matrix multiplication of lhs and rhs:
+     *
+     *  dst = matmul(lhs, rhs)
+     *
+     * Valid data layouts:
+     * - All
+     *
+     * Valid data type configurations:
+     * |lhs            |rhs            |dst            |
+     * |:--------------|:--------------|:--------------|
+     * |F32            |F32            |F32            |
+     * |F16            |F16            |F16            |
+     *
+     * Shape definitions:
+     *       Dim0, Dim1,       Dim2...
+     * lhs: [   K,    M, Batch dims...]
+     * rhs: [   N,    K, Batch dims...]
+     * dst: [   N,    M, Batch dims...]
+     *
+     * Valid shape configurations:
+     * - K must be a multiple of 4 (MMUL_K0).
+     * - No broadcasting in batch dimensions. I.e. batch dims must be the same across lhs, rhs and dst
+     *
+     * @param[in]  compile_context The compile context to be used.
+     * @param[in]  lhs             Input tensor info for the LHS matrix.
+     * @param[in]  rhs             Input tensor info for the RHS matrix.
+     * @param[out] dst             Output tensor info.
+     * @param[in]  matmul_info     Attributes for the batch MatMul kernel.
+     */
+    void configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_info);
+    /** Static function to check if given info will lead to a valid configuration
+     *
+     * Similar to @ref ClMatMulNativeMMULKernel::configure()
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_info);
+
+    // Inherited methods overridden:
+    void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
+
+private:
+    int _m{ 1 };
+    int _n{ 1 };
+};
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
+#endif /* ACL_SRC_GPU_CL_KERNELS_CLMATMULNATIVEMMULKERNEL */
diff --git a/tests/datasets/LargeMatMulMMULDataset.h b/tests/datasets/LargeMatMulMMULDataset.h
new file mode 100644
index 0000000..23e0b3e
--- /dev/null
+++ b/tests/datasets/LargeMatMulMMULDataset.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_TESTS_DATASETS_LARGEMATMULMMULDATASET
+#define ACL_TESTS_DATASETS_LARGEMATMULMMULDATASET
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "tests/datasets/MatMulDataset.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace datasets
+{
+/** MatMul MMUL shapes are similar to MatMul shapes, except that K has to be a multiple of MMUL_K0, which is 4 (see src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp for the definition)
+ */
+class LargeMatMulMMULDataset final : public MatMulDataset
+{
+public:
+    LargeMatMulMMULDataset()
+    {
+        add_config(TensorShape(24U, 13U, 3U, 2U), TensorShape(33U, 24U, 3U, 2U), TensorShape(33U, 13U, 3U, 2U));
+        add_config(TensorShape(36U, 12U, 1U, 5U), TensorShape(21U, 36U, 1U, 5U), TensorShape(21U, 12U, 1U, 5U));
+        add_config(TensorShape(44U, 38U, 3U, 2U), TensorShape(21U, 44U, 3U, 2U), TensorShape(21U, 38U, 3U, 2U));
+    }
+};
+
+class HighDimensionalMatMulMMULDataset final : public MatMulDataset
+{
+public:
+    HighDimensionalMatMulMMULDataset()
+    {
+        add_config(TensorShape(4U, 5U, 2U, 2U, 2U, 2U), TensorShape(5U, 4U, 2U, 2U, 2U, 2U), TensorShape(5U, 5U, 2U, 2U, 2U, 2U)); // 6D tensor
+    }
+};
+
+} // namespace datasets
+} // namespace test
+} // namespace arm_compute
+
+#endif /* ACL_TESTS_DATASETS_LARGEMATMULMMULDATASET */
diff --git a/tests/datasets/SmallMatMulMMULDataset.h b/tests/datasets/SmallMatMulMMULDataset.h
new file mode 100644
index 0000000..9e51748
--- /dev/null
+++ b/tests/datasets/SmallMatMulMMULDataset.h
@@ -0,0 +1,66 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef ACL_TESTS_DATASETS_SMALLMATMULMMULDATASET
+#define ACL_TESTS_DATASETS_SMALLMATMULMMULDATASET
+
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "tests/datasets/MatMulDataset.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace datasets
+{
+/** MatMul MMUL shapes are similar to MatMul shapes, except that K has to be a multiple of MMUL_K0, which is 4 (see src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp for the definition)
+ */
+class SmallMatMulMMULDataset final : public MatMulDataset
+{
+public:
+    SmallMatMulMMULDataset()
+    {
+        add_config(TensorShape(8U, 4U, 2U, 2U), TensorShape(2U, 8U, 2U, 2U), TensorShape(2U, 4U, 2U, 2U));
+        add_config(TensorShape(28U, 1U), TensorShape(23U, 28U), TensorShape(23U, 1U));
+        add_config(TensorShape(8U, 4U, 2U), TensorShape(16U, 8U, 2U), TensorShape(16U, 4U, 2U));
+        add_config(TensorShape(32U, 2U), TensorShape(17U, 32U), TensorShape(17U, 2U));
+        add_config(TensorShape(8U, 6U), TensorShape(7U, 8U), TensorShape(7U, 6U));
+    }
+};
+
+class TinyMatMulMMULDataset final : public MatMulDataset
+{
+public:
+    TinyMatMulMMULDataset()
+    {
+        add_config(TensorShape(4U, 4U), TensorShape(4U, 4U), TensorShape(4U, 4U));
+    }
+};
+
+} // namespace datasets
+} // namespace test
+} // namespace arm_compute
+
+#endif /* ACL_TESTS_DATASETS_SMALLMATMULMMULDATASET */
diff --git a/tests/validation/CL/MatMulNativeMMULKernel.cpp b/tests/validation/CL/MatMulNativeMMULKernel.cpp
new file mode 100644
index 0000000..b33a4fa
--- /dev/null
+++ b/tests/validation/CL/MatMulNativeMMULKernel.cpp
@@ -0,0 +1,348 @@
+/*
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arm_compute/runtime/CL/CLTensor.h"
+#include "src/gpu/cl/kernels/ClMatMulNativeMMULKernel.h"
+#include "tests/datasets/LargeMatMulMMULDataset.h"
+#include "tests/datasets/SmallMatMulMMULDataset.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/MatMulKernelFixture.h"
+#include "tests/validation/reference/Permute.h"
+
+#include <tuple>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+RelativeTolerance<float> tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+constexpr float          abs_tolerance_f32(
+    0.0001f); /**< Absolute tolerance value for comparing reference's output against implementation's output for floating point data types in case using relative tolerance fails because of small values */
+constexpr float abs_tolerance_f16(
+    0.001f);                                                   /**< Absolute tolerance value for comparing reference's output against implementation's output for fp16  data types in case using relative tolerance fails because of small values */
+RelativeTolerance<half_float::half> tolerance_f16(half(0.01)); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */
+} // namespace
+
+/** M0 values to test --precommit*/
+const auto m0_values_precommit = framework::dataset::make("M0", { 1, 3 });
+
+/** N0 values to test --precommit*/
+const auto n0_values_precommit = framework::dataset::make("N0", { 2, 4 });
+
+/** M0 values to test --nightly*/
+const auto m0_values_nightly_lhs_nt = framework::dataset::make("M0", { 1, 2, 3, 4, 5, 6, 7, 8 });
+
+/** N0 values to test --nightly*/
+const auto n0_values_nightly_rhs_nt = framework::dataset::make("N0", { 1, 2, 3, 4, 8, 16 });
+
+/** K0 value -- Fixed to 1 */
+const auto k0_value = framework::dataset::make("K0", { 1 });
+
+template <typename T>
+using CLMatMulNativeMMULKernelFixture = MatMulKernelValidationFixture<T, ClMatMulNativeMMULKernel, true /*use_mmul*/>;
+
+TEST_SUITE(CL)
+TEST_SUITE(MatMulNativeMMULKernel)
+TEST_SUITE(Validate)
+
+TEST_CASE(SupportedBlockSizes, framework::DatasetMode::ALL)
+{
+    if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
+    {
+        using MatMulConfigurationPair = std::pair<MatMulKernelInfo, bool>;
+
+        const std::vector<MatMulConfigurationPair> supported_block_sizes =
+        {
+            // MatMulKernelInfo(adj_lhs, adj_rhs, M0, N0, K0, export_rhs_to_cl_image = false)
+            // Lhs not-transposed, Rhs-not-transposed
+            { MatMulKernelInfo(false, false, 0, 1, 1), false }, // M0 should be > 0
+            { MatMulKernelInfo(false, false, 3, 5, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16}
+            { MatMulKernelInfo(false, false, 3, 6, 1), false }, // N0 not in {1, 2, 3, 4, 8, 16}
+            { MatMulKernelInfo(false, false, 3, 3, 4), false }, // K0 not 1
+            { MatMulKernelInfo(false, false, 9, 1, 1), true },
+            { MatMulKernelInfo(false, false, 3, 16, 1), true },
+            { MatMulKernelInfo(false, false, 7, 3, 1), true },
+
+            // Lhs not-transposed, Rhs transposed
+            // TODO: COMPMID-6195
+
+            // Lhs transposed, Rhs-not-transposed
+            // TODO: COMPMID-6196
+
+            // Lhs transposed, Rhs-transposed
+            // TODO: COMPMID-6197
+        };
+
+        // Set big enough shapes so that block sizes are not truncated. Also, set all dimensions equal
+        // so that it doesn't fail for different NT/T configurations. We aim to test the block sizes here,
+        // not the shapes themselves.
+        const TensorInfo lhs_info = TensorInfo(TensorShape(100U, 100U), 1, DataType::F32);
+        const TensorInfo rhs_info = TensorInfo(TensorShape(100U, 100U), 1, DataType::F32);
+
+        for(auto &pair : supported_block_sizes)
+        {
+            TensorInfo output_info;
+            Status     status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, pair.first);
+
+            ARM_COMPUTE_EXPECT(bool(status) == pair.second, framework::LogLevel::ERRORS);
+        }
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_CASE(ValidateInputShapes, framework::DatasetMode::ALL)
+{
+    if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
+    {
+        // Configurations are assumed to be Nt/Nt, but will be transposed inside the test to test other configurations
+        using ShapeConfigurationTuple = std::tuple<TensorShape, TensorShape, bool>;
+        const std::vector<ShapeConfigurationTuple> shape_configurations =
+        {
+            { TensorShape(4U, 1U), TensorShape(3U, 4U), true },
+            { TensorShape(12U, 12U), TensorShape(3U, 12U), true },
+            { TensorShape(8U, 4U), TensorShape(2U, 8U), true },
+            { TensorShape(8U, 4U), TensorShape(2U, 4U), false }, // Mismatch in the K dimension
+            { TensorShape(5U, 0U), TensorShape(2U, 5U), false }, // Invalid dimension
+            { TensorShape(5U, 7U), TensorShape(2U, 5U), false }, // K not a multiple of 4 (MMUL_K0)
+            { TensorShape(8U, 4U, 3U, 4U, 5U, 6U), TensorShape(2U, 8U, 3U, 4U, 5U, 6U), true },
+            { TensorShape(5U, 4U, 3U, 4U, 5U, 1U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // No batch broadcasting
+            { TensorShape(5U, 4U, 3U, 4U, 9U, 6U), TensorShape(2U, 5U, 3U, 4U, 5U, 6U), false }, // Mismatch in batch dimension
+        };
+
+        for(auto &tuple : shape_configurations)
+        {
+            const bool expected = std::get<2>(tuple);
+
+            for(bool adj_lhs :
+                {
+                    false // TODO: COMPMID-6195, COMPMID-6196, COMPMID-6197
+                })
+            {
+                for(bool adj_rhs :
+                    {
+                        false // TODO: COMPMID-6195, COMPMID-6196, COMPMID-6197
+                    })
+                {
+                    TensorShape lhs_shape = std::get<0>(tuple);
+                    TensorShape rhs_shape = std::get<1>(tuple);
+
+                    if(adj_lhs)
+                    {
+                        permute(lhs_shape, PermutationVector(1U, 0U));
+                    }
+
+                    if(adj_rhs)
+                    {
+                        permute(rhs_shape, PermutationVector(1U, 0U));
+                    }
+
+                    const TensorInfo lhs_info = TensorInfo(lhs_shape, 1, DataType::F32);
+                    const TensorInfo rhs_info = TensorInfo(rhs_shape, 1, DataType::F32);
+                    TensorInfo       output_info;
+
+                    MatMulKernelInfo matmul_kernel_info{ adj_lhs, adj_rhs, 1, 1, 1, false /* export_rhs_to_cl_image */ };
+
+                    Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+                    ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+                }
+            }
+        }
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_CASE(ValidateDataTypes, framework::DatasetMode::ALL)
+{
+    if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()))
+    {
+        // All configurations here use the NT/NT variant with a fixed shape; only the data type combinations are being tested
+        using DataTypeConfigurationTuple = std::tuple<DataType, DataType, DataType, bool>;
+        const std::vector<DataTypeConfigurationTuple> data_type_configurations =
+        {
+            { DataType::F32, DataType::F32, DataType::F32, true },
+            { DataType::F16, DataType::F16, DataType::F16, true },
+            { DataType::F16, DataType::F32, DataType::F32, false },                                              // no mixed precision
+            { DataType::F64, DataType::F64, DataType::F64, false },                                              // no double precision
+            { DataType::QASYMM8, DataType::QASYMM8, DataType::QASYMM8, false },                                  // no quantized types
+            { DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, false },             // no quantized types
+            { DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, DataType::QSYMM8_PER_CHANNEL, false }, // no quantized types
+            { DataType::QASYMM16, DataType::QASYMM16, DataType::QASYMM16, false },                               // no quantized types
+            { DataType::QSYMM16, DataType::QSYMM16, DataType::QSYMM16, false },                                  // no quantized types
+            { DataType::QSYMM8, DataType::QSYMM8, DataType::QSYMM8, false },                                     // no quantized types
+            { DataType::S64, DataType::S64, DataType::S64, false },                                              // no integral types
+            { DataType::S32, DataType::S32, DataType::S32, false },                                              // no integral types
+            { DataType::S16, DataType::S16, DataType::S16, false },                                              // no integral types
+            { DataType::S8, DataType::S8, DataType::S8, false },                                                 // no integral types
+            { DataType::U64, DataType::U64, DataType::U64, false },                                              // no integral types
+            { DataType::U32, DataType::U32, DataType::U32, false },                                              // no integral types
+            { DataType::U16, DataType::U16, DataType::U16, false },                                              // no integral types
+            { DataType::U8, DataType::U8, DataType::U8, false },                                                 // no integral types
+        };
+
+        const TensorShape      shape = TensorShape(8U, 8U);
+        const MatMulKernelInfo matmul_kernel_info{ false, false, 1, 1, 1, false };
+        for(auto &tuple : data_type_configurations)
+        {
+            const bool expected = std::get<3>(tuple);
+
+            const TensorInfo lhs_info(shape, 1, std::get<0>(tuple));
+            const TensorInfo rhs_info(shape, 1, std::get<1>(tuple));
+            TensorInfo       output_info(shape, 1, std::get<2>(tuple));
+
+            Status status = ClMatMulNativeMMULKernel::validate(&lhs_info, &rhs_info, &output_info, matmul_kernel_info);
+            ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+        }
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_SUITE_END() // Validate
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+TEST_SUITE(Buffer)
+FIXTURE_DATA_TEST_CASE(RunTiny, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::TinyMatMulMMULDataset(),
+                                                                                                                     framework::dataset::make("TransposeA", { false })),
+                                                                                                                     framework::dataset::make("TransposeB", { false })),
+                                                                                                                     m0_values_precommit),
+                                                                                                                     n0_values_precommit),
+                                                                                                                     k0_value),
+                                                                                                                     framework::dataset::make("ExportRhsToCLImage", { false })),
+                                                                                                             framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+}
+FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulMMULDataset(),
+                                                                                                                      framework::dataset::make("TransposeA", { false })),
+                                                                                                                      framework::dataset::make("TransposeB", { false })),
+                                                                                                                      m0_values_precommit),
+                                                                                                                      n0_values_precommit),
+                                                                                                                      k0_value),
+                                                                                                                      framework::dataset::make("ExportRhsToCLImage", { false })),
+                                                                                                              framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+                                                                                                                  framework::dataset::make("TransposeA", { false })),
+                                                                                                                  framework::dataset::make("TransposeB", { false })),
+                                                                                                                  m0_values_nightly_lhs_nt),
+                                                                                                                  n0_values_nightly_rhs_nt),
+                                                                                                                  k0_value),
+                                                                                                                  framework::dataset::make("ExportRhsToCLImage", { false })),
+                                                                                                                  framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+}
+// Running the High Dimensional test for FP32 only is enough, because we are stressing the number of dimensions, not the data type or M0/N0/K0
+// It is a good idea to test each Lhs/Rhs T/NT combination because they are different CL kernels
+FIXTURE_DATA_TEST_CASE(RunHighDimensional, CLMatMulNativeMMULKernelFixture<float>, framework::DatasetMode::ALL,
+                       combine(combine(combine(combine(combine(combine(combine(datasets::HighDimensionalMatMulMMULDataset(),
+                                                                               framework::dataset::make("TransposeA", { false })),
+                                                                       framework::dataset::make("TransposeB", { false })),
+                                                               framework::dataset::make("M0", { 2 })),
+                                                       framework::dataset::make("N0", { 2 })),
+                                               framework::dataset::make("K0", { 1 })),
+                                       framework::dataset::make("ExportRhsToCLImage", { false })),
+                               framework::dataset::make("DataType", DataType::F32)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+}
+TEST_SUITE_END() // Buffer
+
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+TEST_SUITE(Buffer)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(datasets::SmallMatMulMMULDataset(),
+                                                                                                                     framework::dataset::make("TransposeA", { false })),
+                                                                                                                     framework::dataset::make("TransposeB", { false })),
+                                                                                                                     m0_values_precommit),
+                                                                                                                     n0_values_precommit),
+                                                                                                                     k0_value),
+                                                                                                                     framework::dataset::make("ExportRhsToCLImage", { false })),
+                                                                                                             framework::dataset::make("DataType", DataType::F16)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+}
+FIXTURE_DATA_TEST_CASE(RunLarge, CLMatMulNativeMMULKernelFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(datasets::LargeMatMulMMULDataset(),
+                                                                                                                 framework::dataset::make("TransposeA", { false })),
+                                                                                                                 framework::dataset::make("TransposeB", { false })),
+                                                                                                                 m0_values_nightly_lhs_nt),
+                                                                                                                 n0_values_nightly_rhs_nt),
+                                                                                                                 k0_value),
+                                                                                                                 framework::dataset::make("ExportRhsToCLImage", { false })),
+                                                                                                                 framework::dataset::make("DataType", DataType::F16)))
+{
+    // Validate output
+    if(_device_supports_mmul)
+    {
+        validate(CLAccessor(_target), _reference, tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+}
+TEST_SUITE_END() // Buffer
+
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // MatMulNativeMMULKernel
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/fixtures/MatMulKernelFixture.h b/tests/validation/fixtures/MatMulKernelFixture.h
index 7d0b1a4..59bcfe5 100644
--- a/tests/validation/fixtures/MatMulKernelFixture.h
+++ b/tests/validation/fixtures/MatMulKernelFixture.h
@@ -47,7 +47,7 @@
 {
 using namespace arm_compute::opencl::kernels;
 
-template <typename T, typename KernelType>
+template <typename T, typename KernelType, bool use_mmul = false>
 class MatMulKernelValidationFixture : public framework::Fixture
 {
 public:
@@ -94,13 +94,25 @@
             permute(shape_b, PermutationVector(1U, 0U));
         }
 
+        // Skip configurations unsupported by the device.
         _device_supports_export_to_cl_image = image2d_from_buffer_supported(CLKernelLibrary::get().get_device());
-
-        if(!export_rhs_to_cl_image || _device_supports_export_to_cl_image)
+        if(!_device_supports_export_to_cl_image && export_rhs_to_cl_image)
         {
-            _target    = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type, lhs_q_info, rhs_q_info, dst_q_info);
-            _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type, lhs_q_info, rhs_q_info, dst_q_info);
+            ARM_COMPUTE_TEST_INFO("cl_khr_image2d_from_buffer not supported. TEST skipped");
+            framework::ARM_COMPUTE_PRINT_INFO();
+            return; // Note: Also need to skip the validate in corresponding FIXTURE_DATA_TEST_CASEs.
         }
+
+        _device_supports_mmul = arm_matrix_multiply_supported(CLKernelLibrary::get().get_device());
+        if(!_device_supports_mmul && use_mmul)
+        {
+            ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+            framework::ARM_COMPUTE_PRINT_INFO();
+            return; // Note: Also need to skip the validate in corresponding FIXTURE_DATA_TEST_CASEs.
+        }
+
+        _target    = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type, lhs_q_info, rhs_q_info, dst_q_info);
+        _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type, lhs_q_info, rhs_q_info, dst_q_info);
     }
 
 protected:
@@ -274,6 +286,7 @@
     CLTensor        _target{};
     SimpleTensor<T> _reference{};
     bool            _device_supports_export_to_cl_image{ true };
+    bool            _device_supports_mmul{ true };
 };
 
 } // namespace validation