Add Gemm MMUL Reshaped Only Rhs Support for FP32/FP16

This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G715 and Arm(R) Mali(TM)-G615.

Resolves: COMPMID-5216
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Change-Id: I2e5d7806f5904347185bb3e250f73d73d6669dba
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7914
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
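
The new kernels distribute work in cooperative MMUL blocks of MMUL_M0 x MMUL_N0 work-items that jointly produce an (M0 * MMUL_M0) x (N0 * MMUL_N0) output tile. A minimal host-side sketch of the index decomposition, assuming the MMUL constants fixed to 4 as in this patch (illustrative only, not part of the change):

```cpp
#include <algorithm>
#include <cstdio>

// Illustrative re-derivation of the work-item decomposition used by
// gemm_mm_reshaped_only_rhs_nt_mmul. MMUL_M0/N0/K0 are fixed to 4 in this
// patch; M0/N0 follow lhs_info.m0/rhs_info.n0.
int main()
{
    const unsigned MMUL_M0 = 4, MMUL_N0 = 4, MMUL_K0 = 4;
    const unsigned M0 = 4, N0 = 4;
    const unsigned M = 64, N = 64;

    const unsigned mmul_block_size = MMUL_N0 * MMUL_K0; // 16 work-items cooperate

    for(unsigned x0 = 0; x0 < 32; ++x0) // global id 0 of a few work-items
    {
        const unsigned block_id  = x0 / mmul_block_size; // which MMUL block along N
        const unsigned thread_id = x0 % mmul_block_size; // lane within the block
        const unsigned block_x   = thread_id % MMUL_N0;
        const unsigned block_y   = thread_id / MMUL_M0;

        // Clamped destination coordinates, as in the kernel: every lane of a
        // block must run the full multiplication before any bounds exit.
        const unsigned y0    = 0;
        const unsigned dst_x = std::min(block_x * N0 + block_id * MMUL_N0 * N0, N - 1);
        const unsigned dst_y = std::min(block_y * M0 + y0 * M0 * MMUL_M0, M - M0);
        std::printf("gid %2u -> block %u lane %2u dst (%u,%u)\n", x0, block_id, thread_id, dst_x, dst_y);
    }
    return 0;
}
```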
diff --git a/Android.bp b/Android.bp
index 2590469..16e67ad 100644
--- a/Android.bp
+++ b/Android.bp
@@ -40,6 +40,7 @@
         "src/core/CL/cl_kernels/common/floor.cl",
         "src/core/CL/cl_kernels/common/gather.cl",
         "src/core/CL/cl_kernels/common/gemm.cl",
+        "src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl",
         "src/core/CL/cl_kernels/common/gemm_utils.cl",
         "src/core/CL/cl_kernels/common/gemmlowp.cl",
         "src/core/CL/cl_kernels/common/gemv.cl",
@@ -617,6 +618,7 @@
         "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp",
         "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp",
         "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp",
+        "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp",
         "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp",
         "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp",
         "src/gpu/cl/kernels/ClHeightConcatenateKernel.cpp",
diff --git a/SConscript b/SConscript
index 358f9dd..6f6b078 100644
--- a/SConscript
+++ b/SConscript
@@ -369,6 +369,7 @@
                        'src/core/CL/cl_kernels/common/floor.cl',
                        'src/core/CL/cl_kernels/common/gather.cl',
                        'src/core/CL/cl_kernels/common/gemm.cl',
+                       'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl',
                        'src/core/CL/cl_kernels/common/gemm_utils.cl',
                        'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl',
                        'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl',
diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h
index 94ec5d7..edbc705 100644
--- a/arm_compute/core/CL/CLHelpers.h
+++ b/arm_compute/core/CL/CLHelpers.h
@@ -260,5 +260,12 @@
  */
 void set_unroll_with_pragma(CLBuildOptions &built_opts, std::initializer_list<int> values);
 
+/** Helper function to check whether the cl_arm_matrix_multiply extension is supported
+ *
+ * @param[in] device A CL device
+ *
+ * @return True if the extension is supported
+ */
+bool arm_matrix_multiply_supported(const cl::Device &device);
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_CLHELPERS_H */
diff --git a/arm_compute/core/GPUTarget.h b/arm_compute/core/GPUTarget.h
index 6a8577a..7e2cfe1 100644
--- a/arm_compute/core/GPUTarget.h
+++ b/arm_compute/core/GPUTarget.h
@@ -51,9 +51,11 @@
     G31           = 0x242,
     G76           = 0x250,
     G77           = 0x310,
+    G57           = 0x311,
     G78           = 0x320,
     G710          = 0x330,
-    G57           = 0x340,
+    G715          = 0x340,
+    G615          = 0x341
 };
 
 /** Enable bitwise operations on GPUTarget enumerations */
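
The enum values encode the GPU architecture in the top nibble, so 0x340/0x341 keep G715/G615 in the Valhall (0x300) family while G57 moves next to the other first-generation Valhall parts. A sketch of that convention, assuming the existing GPU_ARCH_MASK (0xF00) grouping used by get_arch_from_target():

```cpp
#include <cstdio>

// Illustration of the id scheme, assuming the existing convention where the
// top nibble is the architecture (GPU_ARCH_MASK == 0xF00, VALHALL == 0x300).
enum class GPUTarget : int
{
    VALHALL = 0x300,
    G57     = 0x311,
    G715    = 0x340,
    G615    = 0x341
};

int main()
{
    const int arch = static_cast<int>(GPUTarget::G715) & 0xF00;
    std::printf("G715 architecture nibble: 0x%x (Valhall is 0x300)\n", arch);
    return 0;
}
```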
diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h
index bba25c6..d298ecd 100644
--- a/arm_compute/runtime/CL/CLTypes.h
+++ b/arm_compute/runtime/CL/CLTypes.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -35,7 +35,9 @@
     /** Reshaped GEMM kernel where both lhs and rhs matrices are reshaped. Configurable reshape and block size */
     RESHAPED,
     /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Configurable reshape and block size */
-    RESHAPED_ONLY_RHS
+    RESHAPED_ONLY_RHS,
+    /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Uses MMUL with a configurable block size. */
+    RESHAPED_ONLY_RHS_MMUL
 };
 
 /** OpenCL GEMM kernel selection parameters. This information is retrieved to select the GEMM kernel on OpenCL */
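
A hedged sketch of how a caller might branch on the new enumerator; the kernel_family() helper below is hypothetical, the real selection lives in the OpenCL GEMM heuristics:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper, for illustration only: map the kernel type to the
// kernel family used by the kernels in this patch.
enum class CLGEMMKernelType { NATIVE, RESHAPED, RESHAPED_ONLY_RHS, RESHAPED_ONLY_RHS_MMUL };

std::string kernel_family(CLGEMMKernelType type)
{
    switch(type)
    {
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
            return "gemm_mm_reshaped_only_rhs_nt_mmul"; // new in this patch
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
            return "gemm_mm_reshaped_only_rhs_nt";
        default:
            throw std::runtime_error("not covered by this sketch");
    }
}
```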
diff --git a/filelist.json b/filelist.json
index ab2cc83..513a220 100644
--- a/filelist.json
+++ b/filelist.json
@@ -479,6 +479,7 @@
           "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFloatKernel.cpp",
           "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleKernel.cpp",
           "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp",
+          "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp",
           "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp",
           "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp",
           "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp",
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index 5172a77..94675d6 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -491,4 +491,8 @@
     }
 }
 
+bool arm_matrix_multiply_supported(const cl::Device &device)
+{
+    return device_supports_extension(device, "cl_arm_matrix_multiply");
+}
 } // namespace arm_compute
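
Callers are expected to gate the MMUL path on this check. A minimal usage sketch (obtaining the cl::Device is elided):

```cpp
// Sketch: gate kernel selection on the extension. The helper simply wraps
// device_supports_extension(device, "cl_arm_matrix_multiply"), as added above.
#include "arm_compute/core/CL/CLHelpers.h"

bool can_use_mmul_gemm(const cl::Device &device)
{
    return arm_compute::arm_matrix_multiply_supported(device);
}
```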
diff --git a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl
new file mode 100644
index 0000000..8919023
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl
@@ -0,0 +1,528 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "activation_float_helpers.h"
+#include "helpers.h"
+#include "tile_helpers.h"
+
+#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL)
+/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension:
+ *
+ *  The LHS matrix is NOT reshaped
+ *  The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is NOT transposed
+ *
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of output columns processed by the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=2)
+ * @note The number of output rows processed by the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=2)
+ * @note The number of lhs columns (or rhs rows) processed by the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=2)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 > 0
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 1
+ *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
+ *
+ * @param[in]  lhs_ptr                           Pointer to the LHS tensor. Supported data types: F16/F32
+ * @param[in]  lhs_stride_y                      Stride of the LHS tensor in Y dimension (in bytes)
+ * @param[in]  lhs_stride_z                      Stride of the LHS tensor in Z dimension (in bytes)
+ * @param[in]  lhs_w                             The size of the width dimension of the LHS tensor
+ * @param[in]  lhs_h                             The size of the height dimension of the LHS tensor
+ * @param[in]  lhs_n                             The size of the depth dimension of the LHS tensor
+ * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor
+ * @param[in]  rhs_ptr                           Pointer to the RHS reshaped tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_y                      Stride of the RHS tensor in Y dimension (in bytes)
+ * @param[in]  rhs_stride_z                      Stride of the RHS tensor in Z dimension (in bytes)
+ * @param[in]  rhs_w                             The size of the width dimension of the RHS tensor
+ * @param[in]  rhs_h                             The size of the height dimension of the RHS tensor
+ * @param[in]  rhs_n                             The size of the depth dimension of the RHS tensor
+ * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor
+ * @param[in]  bia_ptr                           (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  bia_stride_y                      (Optional) Stride of the bias tensor in Y dimension (in bytes)
+ * @param[in]  bia_stride_z                      (Optional) Stride of the bias tensor in Z dimension (in bytes)
+ * @param[in]  bia_w                             (Optional) The size of the width dimension of the bias tensor
+ * @param[in]  bia_h                             (Optional) The size of the height dimension of the bias tensor
+ * @param[in]  bia_n                             (Optional) The size of the depth dimension of the bias tensor
+ * @param[in]  bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor
+ * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_w                             The size of the width dimension of the destination tensor
+ * @param[in]  dst_h                             The size of the height dimension of the destination tensor
+ * @param[in]  dst_n                             The size of the depth dimension of the destination tensor
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in]  M                                 Number of rows in LHS matrix not reshaped
+ * @param[in]  N                                 Number of columns in RHS matrix not reshaped
+ * @param[in]  K                                 Number of columns in LHS matrix and rows in RHS matrix not reshaped
+ */
+__kernel void gemm_mm_reshaped_only_rhs_nt_mmul(
+    TENSOR3D_T(lhs, BUFFER),
+    TENSOR3D_T(rhs, BUFFER),
+#if defined(BETA)
+    TENSOR3D_T(bia, BUFFER),
+#endif // defined(BETA)
+    TENSOR3D_T(dst, BUFFER),
+    const int M,
+    const int N,
+    const int K)
+{
+#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0)
+
+    uint x0 = get_global_id(0); // (N / N0) * MMUL_K0
+    uint y0 = get_global_id(1); // (M / M0) / MMUL_M0
+    uint z  = get_global_id(2); // Batch
+
+    // Get block ID and thread ID within the block
+    uint block_id  = (x0 / MMUL_BLOCK_SIZE);
+    uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+
+    // Coordinate within a block
+    uint block_x = thread_id % MMUL_N0;
+    uint block_y = (thread_id / MMUL_M0);
+
+    // Starting destination coordinates
+    uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1));
+    uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0));
+
+    // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! Only after the matrix multiplication
+    // part can we exit the kernel if it is out of bounds. Remember, this is a cooperative matrix multiplication, so we need a full block to get the correct results
+
+    // Starting LHS coordinates
+    uint lhs_x = block_x;
+    uint lhs_y = dst_y;
+
+    // Starting RHS coordinates
+    uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
+    uint rhs_y = block_id;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Note: If RHS derives from the weights of a 2D convolution layer, RHS will always be 2D and rhs_stride_z will always be equal to 0, so
+    // the tensor does not slide along Z
+
+    // Initialize the accumulators
+    // The MMUL extension accumulates the result in F32 for both F32 and F16
+    TILE(float, M0, N0, c_f32);
+
+#if !defined(HALF_PRECISION)
+#define c c_f32
+#endif // !defined(HALF_PRECISION)
+
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_f32[i].v = 0;
+    })
+
+    for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, M0, 1, a);
+        TILE(DATA_TYPE, 1, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, 0, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]);
+            })
+        })
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_offset_first_element_in_bytes += MMUL_K0 * MMUL_N0 * N0 * sizeof(DATA_TYPE);
+    }
+
+    if(block_x * N0 + block_id * MMUL_N0 * N0 >= N)
+    {
+        return;
+    }
+
+    if(block_y * M0 + y0 * M0 * MMUL_M0 >= M)
+    {
+        return;
+    }
+
+#if defined(HALF_PRECISION)
+    TILE(DATA_TYPE, M0, N0, c);
+
+    // Conversion required for the half precision
+    LOOP_UNROLLING(int, m0, 0, 1, M0,
+    {
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            c[m0].s[n0] = c_f32[m0].s[n0];
+        })
+    })
+#endif // defined(HALF_PRECISION)
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE);
+
+    TILE(DATA_TYPE, 1, N0, bias0);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+    else
+    {
+        VLOAD_PARTIAL(N0, N0_LEFTOVER)
+        (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c);
+#else // defined(BROADCAST_BIAS)
+    TILE(DATA_TYPE, M0, N0, bias0);
+
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z;
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VLOAD_PARTIAL(N0, N0_LEFTOVER)
+                (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+    T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
+
+    // Store
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#undef MMUL_BLOCK_SIZE
+}
+#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL)
+
+#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)
+/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension and the OpenCL image for RHS:
+ *
+ *  The LHS matrix is NOT reshaped
+ *  The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is NOT transposed
+ *
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of output columns processed by the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=2)
+ * @note The number of output rows processed by the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=2)
+ * @note The number of lhs columns (or rhs rows) processed by the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=2)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 > 0
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 1
+ *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
+ *       The activation function is performed after the bias addition
+ *
+ * @param[in]  lhs_ptr                           Pointer to the LHS tensor. Supported data types: F16/F32
+ * @param[in]  lhs_stride_y                      Stride of the LHS tensor in Y dimension (in bytes)
+ * @param[in]  lhs_stride_z                      Stride of the LHS tensor in Z dimension (in bytes)
+ * @param[in]  lhs_w                             The size of the width dimension of the LHS tensor
+ * @param[in]  lhs_h                             The size of the height dimension of the LHS tensor
+ * @param[in]  lhs_n                             The size of the depth dimension of the LHS tensor
+ * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor
+ * @param[in]  rhs_ptr                           Pointer to the RHS reshaped tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  rhs_stride_y                      Stride of the RHS tensor in Y dimension (in bytes)
+ * @param[in]  rhs_stride_z                      Stride of the RHS tensor in Z dimension (in bytes)
+ * @param[in]  rhs_w                             The size of the width dimension of the RHS tensor
+ * @param[in]  rhs_h                             The size of the height dimension of the RHS tensor
+ * @param[in]  rhs_n                             The size of the depth dimension of the RHS tensor
+ * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor
+ * @param[in]  bia_ptr                           (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  bia_stride_y                      (Optional) Stride of the bias tensor in Y dimension (in bytes)
+ * @param[in]  bia_stride_z                      (Optional) Stride of the bias tensor in Z dimension (in bytes)
+ * @param[in]  bia_w                             (Optional) The size of the width dimension of the bias tensor
+ * @param[in]  bia_h                             (Optional) The size of the height dimension of the bias tensor
+ * @param[in]  bia_n                             (Optional) The size of the depth dimension of the bias tensor
+ * @param[in]  bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor
+ * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data type: same as @p lhs_ptr
+ * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_w                             The size of the width dimension of the destination tensor
+ * @param[in]  dst_h                             The size of the height dimension of the destination tensor
+ * @param[in]  dst_n                             The size of the depth dimension of the destination tensor
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in]  M                                 Number of rows in LHS matrix not reshaped
+ * @param[in]  N                                 Number of columns in RHS matrix not reshaped
+ * @param[in]  K                                 Number of columns in LHS matrix and rows in RHS matrix not reshaped
+ */
+__kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture(
+    TENSOR3D_T(lhs, BUFFER),
+    TENSOR3D_T(rhs, IMAGE),
+#if defined(BETA)
+    TENSOR3D_T(bia, BUFFER),
+#endif // defined(BETA)
+    TENSOR3D_T(dst, BUFFER),
+    const int M,
+    const int N,
+    const int K)
+{
+#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0)
+
+    uint x0 = get_global_id(0); // (N / N0) * MMUL_K0
+    uint y0 = get_global_id(1); // (M / M0) / MMUL_M0
+    uint z  = get_global_id(2); // Batch
+
+    // Get block ID and thread ID within the block
+    uint block_id  = (x0 / MMUL_BLOCK_SIZE);
+    uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+
+    // Coordinate within a block
+    uint block_x = thread_id % MMUL_N0;
+    uint block_y = (thread_id / MMUL_M0);
+
+    // Starting destination coordinates
+    uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1));
+    uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0));
+
+    // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! Only after the matrix multiplication
+    // part can we exit the kernel if it is out of bounds. Remember, this is a cooperative matrix multiplication, so we need a full block to get the correct results
+
+    // Starting LHS coordinates
+    uint lhs_x = block_x;
+    uint lhs_y = dst_y;
+
+    // Starting RHS coordinates
+    uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
+    uint rhs_y = block_id + z * rhs_h;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    // The MMUL extension accumulates the result in F32 for both F32 and F16
+    TILE(float, M0, N0, c_f32);
+
+#if !defined(HALF_PRECISION)
+#define c c_f32
+#endif // !defined(HALF_PRECISION)
+
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_f32[i].v = 0;
+    })
+
+    for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, M0, 1, a);
+        TILE(DATA_TYPE, 1, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, 1, N0, IMAGE, rhs, rhs_x, rhs_y, 1, rhs_stride_y, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]);
+            })
+        })
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_x += MMUL_K0 * MMUL_N0 * N0;
+    }
+
+    if(block_x * N0 + block_id * MMUL_N0 * N0 >= N)
+    {
+        return;
+    }
+
+    if(block_y * M0 + y0 * M0 * MMUL_M0 >= M)
+    {
+        return;
+    }
+
+#if defined(HALF_PRECISION)
+    TILE(DATA_TYPE, M0, N0, c);
+
+    // Conversion required for the half precision
+    LOOP_UNROLLING(int, m0, 0, 1, M0,
+    {
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            c[m0].s[n0] = c_f32[m0].s[n0];
+        })
+    })
+#endif // defined(HALF_PRECISION)
+
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE);
+
+    TILE(DATA_TYPE, 1, N0, bias0);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+    else
+    {
+        VLOAD_PARTIAL(N0, N0_LEFTOVER)
+        (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c);
+#else // defined(BROADCAST_BIAS)
+    TILE(DATA_TYPE, M0, N0, bias0);
+
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z;
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VLOAD_PARTIAL(N0, N0_LEFTOVER)
+                (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+    T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
+
+    // Store
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#undef MMUL_BLOCK_SIZE
+}
+#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)
\ No newline at end of file
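
For reference, the math both kernels implement per destination element, written as a plain scalar GEMM (illustrative C++; the RHS reshape only changes memory layout, not the result):

```cpp
#include <cstddef>
#include <vector>

// Scalar reference for dst = alpha * (lhs * rhs) + beta * bias, the
// computation the MMUL kernels implement tile-by-tile. Row-major, no
// reshaping.
void gemm_ref(std::size_t M, std::size_t N, std::size_t K,
              float alpha, const std::vector<float> &lhs,
              const std::vector<float> &rhs, float beta,
              const std::vector<float> &bias, std::vector<float> &dst)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f; // MMUL accumulates in F32 even for F16 inputs
            for(std::size_t k = 0; k < K; ++k)
                acc += lhs[m * K + k] * rhs[k * N + n];
            dst[m * N + n] = alpha * acc + beta * bias[m * N + n];
        }
}
```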
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index 0ce343e..4b6144a 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -970,8 +970,8 @@
 #define ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) op##_op_quantized(DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x)
 #define ACTIVATION_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x)
 
-#define T_ADD(A_VAL, B_VAL) ((A_VAL) + (B_VAL))
-#define T_DIV(A_VAL, B_VAL) ((A_VAL) / (B_VAL))
+#define V_ADD(A_VAL, B_VAL) ((A_VAL) + (B_VAL))
+#define V_DIV(A_VAL, B_VAL) ((A_VAL) / (B_VAL))
 
 /** Element-wise activation for quantized types
  *
@@ -995,6 +995,25 @@
         })                                                                                          \
     })
 
+/** Element-wise addition between two tiles
+ *
+ * @note Performs: LHS + RHS = DST
+ *
+ * @param[in]  DATA_TYPE LHS/RHS/DST data type
+ * @param[in]  M0        Number of LHS rows
+ * @param[in]  N0        Number of LHS columns
+ * @param[in]  lhs       LHS tile
+ * @param[in]  rhs       Constant RHS tile
+ * @param[out] dst       DST tile
+ */
+#define T_ADD(DATA_TYPE, M0, N0, lhs, rhs, dst) \
+    ({                                                            \
+        LOOP_UNROLLING(int, _m0, 0, 1, M0,                        \
+        {                                                         \
+            dst[_m0].v = lhs[_m0].v + rhs[_m0].v; \
+        })                                                        \
+    })
+
 /** Element-wise addition with a constant value
  *
  * @note Performs: LHS + constant = DST
@@ -1010,15 +1029,31 @@
     ({                                                            \
         LOOP_UNROLLING(int, _m0, 0, 1, M0,                        \
         {                                                         \
-            LOOP_UNROLLING(int, _n0, 0, 1, N0,                    \
-            {                                                     \
-                dst[_m0].s[_n0] = lhs[_m0].s[_n0] + rhs_constant; \
-            })                                                    \
+            dst[_m0].v = lhs[_m0].v + (DATA_TYPE)rhs_constant;               \
         })                                                        \
     })
 
-#define T_ELTWISE_BROADCAST_ADD_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(T_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
-#define T_ELTWISE_BROADCAST_DIV_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(T_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_BROADCAST_ADD_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_BROADCAST_DIV_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+
+/** Element-wise scale with a constant value
+ *
+ * @note Performs: LHS * constant = DST
+ *
+ * @param[in]  DATA_TYPE    LHS/RHS/DST data type
+ * @param[in]  M0           Number of LHS rows
+ * @param[in]  N0           Number of LHS columns
+ * @param[in]  lhs          LHS tile
+ * @param[in]  rhs_constant Constant value
+ * @param[out] dst          DST tile
+ */
+#define T_SCALE_CONSTANT(DATA_TYPE, M0, N0, lhs, rhs_constant, dst) \
+    ({                                                            \
+        LOOP_UNROLLING(int, _m0, 0, 1, M0,                        \
+        {                                                         \
+            dst[_m0].v = lhs[_m0].v * (DATA_TYPE)rhs_constant; \
+        })                                                        \
+    })
 
 /** Element-wise operation with RHS broadcasted (RHS has the X dimension only)
  *
@@ -1041,8 +1076,8 @@
         })                                                  \
     })
 
-#define T_ELTWISE_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(T_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
-#define T_ELTWISE_DIV(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(T_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_DIV(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
 
 /** Element-wise operation between two tiles (LHS and RHS)
  *
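
The new T_ADD and T_SCALE_CONSTANT macros act on whole row vectors; their scalar meaning, as a sketch:

```cpp
#include <array>

// Scalar analogue of the new tile macros: T_ADD adds two M0xN0 tiles,
// T_SCALE_CONSTANT multiplies a tile by a scalar. The OpenCL versions act on
// whole .v vectors per row instead of per element.
template <typename T, int M0, int N0>
using Tile = std::array<std::array<T, N0>, M0>;

template <typename T, int M0, int N0>
void t_add(const Tile<T, M0, N0> &lhs, const Tile<T, M0, N0> &rhs, Tile<T, M0, N0> &dst)
{
    for(int m = 0; m < M0; ++m)
        for(int n = 0; n < N0; ++n)
            dst[m][n] = lhs[m][n] + rhs[m][n];
}

template <typename T, int M0, int N0>
void t_scale_constant(const Tile<T, M0, N0> &lhs, T c, Tile<T, M0, N0> &dst)
{
    for(int m = 0; m < M0; ++m)
        for(int n = 0; n < N0; ++n)
            dst[m][n] = lhs[m][n] * c;
}
```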
diff --git a/src/core/GPUTarget.cpp b/src/core/GPUTarget.cpp
index 5984c88..e74abf6 100644
--- a/src/core/GPUTarget.cpp
+++ b/src/core/GPUTarget.cpp
@@ -47,6 +47,14 @@
     {
         return arm_compute::GPUTarget::G57;
     }
+    else if(version.find("G715") != std::string::npos)
+    {
+        return arm_compute::GPUTarget::G715;
+    }
+    else if(version.find("G615") != std::string::npos)
+    {
+        return arm_compute::GPUTarget::G615;
+    }
     else
     {
         return arm_compute::GPUTarget::UNKNOWN;
@@ -141,7 +149,9 @@
         { GPUTarget::G77, "g77" },
         { GPUTarget::G78, "g78" },
         { GPUTarget::G710, "g710" },
-        { GPUTarget::G57, "g57" }
+        { GPUTarget::G57, "g57" },
+        { GPUTarget::G715, "g715" },
+        { GPUTarget::G615, "g615" }
     };
 
     return gpu_target_map[target];
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 1bf7f2b..52661d6 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -272,6 +272,8 @@
     { "gemm_mv", "common/gemv.cl" },
     { "gemm_mv_quantized", "common/gemv.cl" },
     { "gemm_mm_native", "common/gemm.cl" },
+    { "gemm_mm_reshaped_only_rhs_nt_mmul", "common/gemm_reshaped_only_rhs_mmul.cl" },
+    { "gemm_mm_reshaped_only_rhs_nt_mmul_texture", "common/gemm_reshaped_only_rhs_mmul.cl" },
     { "gemm_mm_native_post_act_eltwise_op_act", "common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl" },
     { "gemm_mm_reshaped_lhs_nt_rhs_t", "common/gemm.cl" },
     { "gemm_mm_reshaped_lhs_nt_rhs_t_texture", "common/gemm.cl" },
@@ -584,6 +586,10 @@
 #include "./cl_kernels/common/gemm.clembed"
     },
     {
+        "common/gemm_reshaped_only_rhs_mmul.cl",
+#include "./cl_kernels/common/gemm_reshaped_only_rhs_mmul.clembed"
+    },
+    {
         "common/gemm_utils.cl",
 #include "./cl_kernels/common/gemm_utils.clembed"
     },
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp
new file mode 100644
index 0000000..fe46913
--- /dev/null
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp
@@ -0,0 +1,365 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
+
+#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibrary.h"
+#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/OpenCL.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "src/core/CL/CLUtils.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/helpers/WindowHelpers.h"
+#include "src/core/utils/helpers/float_ops.h"
+#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
+#include "support/Cast.h"
+#include "support/StringSupport.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+namespace kernels
+{
+namespace
+{
+using ElementsProcessed = Steps;
+
+// Block size dimensions for the MMUL extension
+constexpr int mmul_m0 = 4;
+constexpr int mmul_n0 = 4;
+constexpr int mmul_k0 = 4;
+
+Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+                          const GEMMRHSMatrixInfo &rhs_info,
+                          const GEMMKernelInfo    &gemm_info)
+{
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform");
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_info.m0 < 1, "Only values greater than 0 are supported for m0");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.n0 != 1 && rhs_info.n0 != 2 && rhs_info.n0 != 3 && rhs_info.n0 != 4 && rhs_info.n0 != 8 && rhs_info.n0 != 16, "Only 1,2,3,4,8, and 16 are supported for n0");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 != 1 || lhs_info.k0 != 1), "Only 1 is supported for k0");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.h0 != 4), "Only 4 is supported for h0");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.interleave != true, "Only true is supported for interleave with mmul extension enabled");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.transpose != false, "Only false is supported for transpose with mmul extension enabled");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported");
+    ARM_COMPUTE_RETURN_ON_ERROR(gemm::validate_image2d_support_on_rhs(*src1, rhs_info));
+
+    const unsigned int m = gemm_info.m;
+    const unsigned int n = gemm_info.n;
+    const unsigned int k = gemm_info.k;
+
+    ARM_COMPUTE_UNUSED(m);
+    ARM_COMPUTE_UNUSED(n);
+    ARM_COMPUTE_UNUSED(k);
+
+    ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k);
+
+    // Validate the reinterpreted-as-3D-case
+    if(gemm_info.depth_output_gemm3d != 0)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m);
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) != m);
+    }
+
+    // Validate the gemm-batched case
+    if(src1->num_dimensions() > 2)
+    {
+        if(gemm_info.depth_output_gemm3d != 0)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(3) != src1->dimension(2));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(2) != src1->dimension(2));
+        }
+    }
+
+    if(src2 != nullptr && !(helpers::float_ops::is_zero(beta)))
+    {
+        const unsigned int src2_dim0 = src2->dimension(0);
+        const unsigned int src2_dim1 = src2->dimension(1);
+
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src2, src1);
+        if(gemm_info.broadcast_bias)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim1 != 1 || src2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted");
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim0 != n || src2_dim1 != m), "Incorrect dimension of bias matrix");
+        }
+    }
+
+    TensorShape tensor_shape1{ src1->tensor_shape() };
+    tensor_shape1.set(0, n);
+    tensor_shape1.set(1, k);
+
+    const TensorInfo tensor_info1          = src1->clone()->set_tensor_shape(tensor_shape1);
+    const TensorInfo tensor_info_reshaped1 = src1->clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(tensor_info1, rhs_info));
+
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src1, &tensor_info_reshaped1);
+
+    if(dst->total_size() != 0)
+    {
+        const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst);
+    }
+
+    return Status{};
+}
+
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info,
+                                                        const GEMMRHSMatrixInfo &rhs_info,
+                                                        const GEMMKernelInfo    &gemm_info)
+{
+    ARM_COMPUTE_UNUSED(src0, src1, src2);
+    bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
+
+    // dst tensor auto initialization if not yet initialized
+    auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+
+    TensorInfo tmp_info(*dst);
+
+    if(reinterpret_output_as_3d)
+    {
+        // Since the dst tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM,
+        // the window needs to be constructed on the 2D collapsed version of the tensor
+        TensorShape tmp_shape(dst->tensor_shape());
+        tmp_shape.collapse(2U, 1U);
+        tmp_info.set_tensor_shape(tmp_shape);
+    }
+
+    Window win = calculate_max_window(tmp_info, Steps(1, 1));
+
+    // Collapse along the Z direction
+    // This collapse needs to be here in order to tune the Z dimension of LWS
+    const unsigned int dimension_to_collapse = std::min(static_cast<unsigned int>(dst->num_dimensions()), 2u);
+    Window             collapsed             = win.collapse(win, dimension_to_collapse);
+
+    // Reconfigure the window size: one cooperative arm_matrix_multiply block needs 16 work-items to finish.
+    Window::Dimension x_dimension = collapsed.x();
+    Window::Dimension y_dimension = collapsed.y();
+
+    // Make M and N multiple of M0 and N0 respectively
+    const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(x_dimension.end(), rhs_info.n0);
+    const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(y_dimension.end(), lhs_info.m0);
+
+    // Divide M and N by M0 and N0 respectively
+    const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / rhs_info.n0;
+    const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / lhs_info.m0;
+
+    // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_k0 respectively
+    const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0);
+    const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0);
+
+    // Ensure x_dimension is multiple of MMUL block size (mmul_n0 * mmul_k0)
+    x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_k0);
+    y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0);
+
+    collapsed.set(Window::DimX, x_dimension);
+    collapsed.set(Window::DimY, y_dimension);
+
+    return std::make_pair(Status{}, collapsed);
+}
+} // namespace
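
The window arithmetic in validate_and_configure_window(), worked for a concrete shape (illustrative; mirrors the code above):

```cpp
#include <cstdio>

// Re-derives the global work size produced by validate_and_configure_window
// for a sample problem. Values are illustrative.
unsigned ceil_to_multiple(unsigned v, unsigned m) { return ((v + m - 1) / m) * m; }

int main()
{
    const unsigned M = 64, N = 64, M0 = 4, N0 = 4;
    const unsigned mmul_n0 = 4, mmul_k0 = 4;

    const unsigned n_div_n0 = ceil_to_multiple(N, N0) / N0; // 16
    const unsigned m_div_m0 = ceil_to_multiple(M, M0) / M0; // 16

    // x carries the MMUL block lanes, y counts rows of blocks
    const unsigned gws_x = ceil_to_multiple(n_div_n0, mmul_n0) * mmul_k0; // 64
    const unsigned gws_y = ceil_to_multiple(m_div_m0, mmul_k0) / mmul_k0; // 4

    std::printf("gws = (%u, %u)\n", gws_x, gws_y);
    return 0;
}
```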
+
+ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel()
+{
+    _type = CLKernelType::GEMM;
+}
+
+void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha,
+                                                              float                    beta,
+                                                              const GEMMLHSMatrixInfo &lhs_info,
+                                                              const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
+
+    // dst tensor auto initialization if not yet initialized
+    auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)));
+
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
+
+    auto padding_info   = get_padding_info({ src0, src1, src2, dst });
+    _add_bias           = src2 != nullptr;
+    _export_to_cl_image = rhs_info.export_to_cl_image;
+
+    // Configure kernel window
+    auto win_config = validate_and_configure_window(src0, src1, src2, dst, lhs_info, rhs_info, gemm_info);
+    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+
+    IClKernel::configure_internal(win_config.second);
+
+    _m = gemm_info.m;
+    _n = gemm_info.n;
+    _k = gemm_info.k;
+
+    const unsigned int m0_leftover = _m % lhs_info.m0;
+    const unsigned int n0_leftover = _n % rhs_info.n0;
+
+    // Create build options
+    CLBuildOptions build_opts;
+    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src0->data_type()));
+    build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha));
+    build_opts.add_option_if(src2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta));
+    build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA");
+    build_opts.add_option_if(gemm_info.broadcast_bias, "-DBROADCAST_BIAS");
+    build_opts.add_option_if(src0->data_type() == DataType::F16, "-DHALF_PRECISION");
+    build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
+    build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
+    build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
+    build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover));
+    build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover));
+    build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0));
+    build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0));
+    build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0));
+    build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation())));
+    build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a()));
+    build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b()));
+
+    std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul");
+    kernel_name += rhs_info.export_to_cl_image ? "_texture" : "";
+
+    // A macro guard to compile ONLY the kernel of interest
+    build_opts.add_option("-D" + upper_string(kernel_name));
+
+    // Create kernel
+    _kernel = create_kernel(compile_context, kernel_name, build_opts.options());
+
+    // Set config_id for enabling LWS tuning
+    _config_id = kernel_name;
+    _config_id += "_";
+    _config_id += (_add_bias ? "add_bias_" : "");
+    _config_id += (gemm_info.broadcast_bias ? "broadcast_bias_" : "");
+    _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : "");
+    _config_id += lower_string(string_from_data_type(src0->data_type()));
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(_m);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(_n);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(_k);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(lhs_info.m0);
+    _config_id += "_";
+    _config_id += support::cpp11::to_string(rhs_info.n0);
+
+    ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info));
+}
+
+Status ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta,
+                                                               const GEMMLHSMatrixInfo &lhs_info,
+                                                               const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info)
+{
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src0->clone().get(),
+                                                              src1->clone().get(),
+                                                              src2 != nullptr ? src2->clone().get() : nullptr,
+                                                              dst->clone().get(),
+                                                              lhs_info,
+                                                              rhs_info,
+                                                              gemm_info)
+                                .first);
+
+    return Status{};
+}
+
+void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue)
+{
+    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
+
+    const auto src0 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0));
+    const auto src1 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1));
+    const auto src2 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2));
+    auto       dst  = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST));
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
+    ARM_COMPUTE_ERROR_ON(_add_bias && src2 == nullptr);
+
+    if(src1->info()->num_dimensions() < 3)
+    {
+        // The stride_z for matrix B must be zero if we do not slice
+        ARM_COMPUTE_ERROR_ON(src1->info()->strides_in_bytes()[3] != 0);
+    }
+
+    cl::Image2D src1_image2d;
+
+    if(_export_to_cl_image)
+    {
+        const TensorShape shape2d(src1->info()->dimension(0) / 4, src1->info()->dimension(1) * src1->info()->dimension(2));
+        const size_t      image_row_pitch = src1->info()->strides_in_bytes()[1];
+
+        src1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), src1->cl_buffer(), shape2d, src1->info()->data_type(), image_row_pitch);
+    }
+
+    Window slice = window.first_slice_window_3D();
+
+    do
+    {
+        unsigned int idx = 0;
+
+        add_3d_tensor_nhw_argument(idx, src0);
+        if(_export_to_cl_image)
+        {
+            _kernel.setArg(idx++, src1_image2d);
+        }
+        add_3d_tensor_nhw_argument(idx, src1);
+
+        // Bias buffer (_add_bias == true)
+        if(_add_bias)
+        {
+            add_3d_tensor_nhw_argument(idx, src2);
+        }
+        // dst buffer
+        add_3d_tensor_nhw_argument(idx, dst);
+
+        // Pass m, n and k at runtime as signed ints to ensure that the result of any subtraction they appear in remains signed.
+        _kernel.setArg<cl_int>(idx++, _m);
+        _kernel.setArg<cl_int>(idx++, _n);
+        _kernel.setArg<cl_int>(idx++, _k);
+
+        // LWS_x should be at least a multiple of 16. (32, 2) has been chosen to place more work-items on a single core
+        // The LWS also enforces the execution order of the work-items, which improves cache utilization
+        enqueue(queue, *this, slice, cl::NDRange(32, 2), false);
+    }
+    while(window.slide_window_slice_3D(slice));
+}
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
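
A hedged sketch of probing the kernel with the static validate() before configuring it; the wrapper name is hypothetical and tensor-info construction is elided:

```cpp
// Sketch: ask the static validator whether this configuration can run on the
// MMUL kernel. The constraints checked include: lhs_info.m0 > 0,
// rhs_info.n0 in {1,2,3,4,8,16}, k0 == 1, h0 == 4, interleave == true,
// transpose == false, and the cl_arm_matrix_multiply extension being present.
#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"

using namespace arm_compute;

bool mmul_config_is_valid(const ITensorInfo &lhs, const ITensorInfo &rhs_reshaped,
                          const ITensorInfo &dst,
                          const GEMMLHSMatrixInfo &lhs_info,
                          const GEMMRHSMatrixInfo &rhs_info,
                          const GEMMKernelInfo    &gemm_info)
{
    const Status status = opencl::kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(
        &lhs, &rhs_reshaped, /* bias */ nullptr, &dst, /* alpha */ 1.f, /* beta */ 0.f,
        lhs_info, rhs_info, gemm_info);
    return bool(status); // Status converts to true when validation succeeded
}
```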
diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h
new file mode 100644
index 0000000..59612fc
--- /dev/null
+++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H
+#define ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H
+
+#include "arm_compute/core/KernelDescriptors.h"
+#include "src/core/common/Macros.h"
+#include "src/gpu/cl/ClCompileContext.h"
+#include "src/gpu/cl/IClKernel.h"
+
+namespace arm_compute
+{
+namespace opencl
+{
+namespace kernels
+{
+/** OpenCL kernel to multiply matrices using MMUL when only the input matrix RHS (src1) has been reshaped */
+class ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel : public IClKernel
+{
+public:
+    ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel();
+    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel);
+    /** Initialize the kernel's inputs and dst.
+     *
+     * @param[in]  compile_context The compile context to be used.
+     * @param[in]  src0            Input tensor for the LHS matrix. Data type supported: F16/F32.
+     * @param[in]  src1            Input tensor containing the RHS reshaped matrix. Data type supported: same as @p src0.
+     * @param[in]  src2            Input tensor containing the bias matrix. Data type supported: same as @p src0.
+     * @param[out] dst             dst tensor info. Data type supported: same as @p src0
+     * @param[in]  alpha           Weight of the matrix product
+     * @param[in]  beta            Weight of the matrix bias
+     * @param[in]  lhs_info        LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported:
+     *                             lhs_info.m0 > 0
+     *                             lhs_info.k0: 1
+     * @param[in]  rhs_info        RHS matrix information used to retrieve the number of columns and accumulations to be processed by each thread. Only the following values are supported:
+     *                             rhs_info.n0: 1,2,3,4,8,16
+     *                             rhs_info.k0: same as lhs_info.k0
+     *                             rhs_info.transpose: false
+     * @param[in]  gemm_info       GEMM information used to retrieve the original dimensions of the input matrices
+     */
+    void configure(const ClCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta,
+                   const GEMMLHSMatrixInfo &lhs_info,
+                   const GEMMRHSMatrixInfo &rhs_info,
+                   const GEMMKernelInfo    &gemm_info);
+    /** Static function to check if given info will lead to a valid configuration
+     *
+     * Similar to @ref ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure()
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info,
+                           const GEMMRHSMatrixInfo &rhs_info,
+                           const GEMMKernelInfo    &gemm_info);
+
+    // Inherited methods overridden:
+    void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override;
+
+private:
+    bool       _add_bias{ false };
+    bool       _export_to_cl_image{ false };
+    signed int _m{ 1 };
+    signed int _n{ 1 };
+    signed int _k{ 1 };
+};
+} // namespace kernels
+} // namespace opencl
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H */
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index 1bf27ba..67da061 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -110,6 +110,23 @@
 
     return Status{};
 }
+
+bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b,
+                              const DataType data_type, unsigned int &best_m0, unsigned int &best_n0)
+{
+    ARM_COMPUTE_UNUSED(n, b, data_type);
+
+    const unsigned int mmul_k0 = 4;
+    best_m0                    = 4;
+    best_n0                    = 4;
+
+    const unsigned int ceil_to_multiple_m_m0             = ceil_to_multiple(m, best_m0);
+    const unsigned int m_div_m0                          = ceil_to_multiple_m_m0 / best_m0;
+    const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0);
+    const unsigned int gws_y                             = ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0;
+
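+    // Worked example (illustrative): m = 100 gives ceil_to_multiple(100, 4) = 100,
+    // m_div_m0 = 25, ceil_to_multiple(25, 4) = 28 and gws_y = 7; the MMUL kernel is
+    // then preferred provided k is also a multiple of mmul_k0 (4)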
+    return ((k % mmul_k0) == 0) && (gws_y > 4);
+}
 } // namespace gemm
 } // namespace kernels
 } // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
index 3fce8c9..bf1e8fc 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -88,6 +88,21 @@
  * @return Status reporting if we can use the image2d OpenCL object on the RHS reshaped matrix
  */
 Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info);
+
+/** Determine if the MMUL kernels should be preferred
+ *
+ * @param[in]      m         Number of rows of the LHS matrix
+ * @param[in]      n         Number of columns of the RHS matrix
+ * @param[in]      k         Number of columns of the LHS matrix, rows of the RHS matrix
+ * @param[in]      b         Batch size
+ * @param[in]      data_type Data type FP32/FP16
+ * @param[out]     best_m0   Suggested M0 (number of rows of the output block) for the kernel
+ * @param[out]     best_n0   Suggested N0 (number of columns of the output block) for the kernel
+ *
+ * @return true if MMUL kernel is preferred over kernels w/o MMUL, false otherwise
+ */
+bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b,
+                              const DataType data_type, unsigned int &best_m0, unsigned int &best_n0);
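+
+// Example usage (illustrative sketch):
+//
+//   unsigned int best_m0 = 0, best_n0 = 0;
+//   if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
+//   {
+//       // Select CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL and pass best_m0/best_n0 on
+//   }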
 } // namespace gemm
 } // namespace kernels
 } // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
index a82084a..9776298 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,7 +29,9 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
+#include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h"
 
 #include <utility>
 
@@ -61,6 +63,10 @@
                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
 
+    CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
+                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
+                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
+
     ConfigurationFunctionExecutorPtr func = nullptr;
 
     switch(_target)
@@ -68,6 +74,10 @@
         case GPUTarget::G78:
             func = configs_G78.get_function(data_type);
             break;
+        case GPUTarget::G715:
+        case GPUTarget::G615:
+            func = configs_G715.get_function(data_type);
+            break;
         case GPUTarget::G77:
         default:
             func = configs_G77.get_function(data_type);
@@ -564,6 +574,36 @@
         }
     }
 }
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    unsigned int best_m0;
+    unsigned int best_n0;
+
+    if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
+    {
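+        // The positional arguments map to (m, n, m0, n0, k0, v0, h0, lhs_interleave, rhs_interleave,
+        // lhs_transpose, rhs_transpose, export_to_cl_image) in the configure_lhs_rhs_info() helper,
+        // i.e. k0 = 1, v0 = 1, h0 = 4, RHS interleaved, nothing transposed, CL image export enabled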
+        return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
+    }
+    else
+    {
+        return configure_G77_f32(m, n, k, b);
+    }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    unsigned int best_m0;
+    unsigned int best_n0;
+
+    if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
+    {
+        return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
+    }
+    else
+    {
+        return configure_G78_f16(m, n, k, b);
+    }
+}
 } // namespace gemm
 } // namespace kernels
 } // namespace opencl
diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
index c5e80a7..0ec068f 100644
--- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
+++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -53,6 +53,8 @@
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 };
 } // namespace gemm
 } // namespace kernels
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index 88f6b79..4db39a6 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -191,6 +191,7 @@
       _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
       _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
       _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
+      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
       _tmp_a(),
       _tmp_b(),
       _reshape_b_only_on_first_run(false),
@@ -324,6 +325,53 @@
     _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
 }
 
+void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
+                                              const GEMMInfo &gemm_info)
+{
+    DataType           data_type               = a->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    bool               broadcast_bias          = gemm_info.broadcast_bias();
+
+    GEMMKernelInfo kernel_info;
+    kernel_info.m                       = m;
+    kernel_info.n                       = n;
+    kernel_info.k                       = k;
+    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
+    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+    kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
+    kernel_info.post_ops                = gemm_info.post_ops();
+
+    // Set the target for the kernels
+    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);
+
+    GEMMLHSMatrixInfo lhs_info{};
+    GEMMRHSMatrixInfo rhs_info{};
+
+    // Pick up the GEMM configuration
+    auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+    lhs_info         = gemm_config.lhs_info;
+    rhs_info         = gemm_config.rhs_info;
+    // Force H0 to 4 in order to use the MMUL extension
+    rhs_info.h0 = 4;
+
+    // Reshape Rhs matrix
+    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
+
+    // Configure matrix multiply kernel with no y padding support
+    kernel_info.has_pad_y = false;
+    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
+
+    // Request memory for RHS reshape matrix
+    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
+}
+
 Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
@@ -458,6 +506,54 @@
     return Status{};
 }
 
+Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(output);
+    TensorInfo tmp_b_info{};
+
+    // Get the GPU target
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    const DataType     data_type               = a->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool         broadcast_bias          = gemm_info.broadcast_bias();
+
+    GEMMKernelInfo kernel_info;
+    kernel_info.m                       = m;
+    kernel_info.n                       = n;
+    kernel_info.k                       = k;
+    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
+    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+    kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
+    kernel_info.post_ops                = gemm_info.post_ops();
+
+    GEMMLHSMatrixInfo lhs_info;
+    GEMMRHSMatrixInfo rhs_info;
+
+    // Pick up the GEMM configuration
+    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+    lhs_info               = gemm_config.lhs_info;
+    rhs_info               = gemm_config.rhs_info;
+    // Force H0 to 4 in order to use the MMUL extension
+    rhs_info.h0 = 4;
+
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+    // Validate matrix multiply
+    kernel_info.has_pad_y = false;
+    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
+
+    return Status{};
+}
+
 void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
@@ -501,6 +597,11 @@
             configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
+            break;
+        }
         default:
         {
             ARM_COMPUTE_ERROR("GEMMType not supported");
@@ -545,6 +646,11 @@
             ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
+            break;
+        }
         default:
         {
             ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
@@ -627,6 +733,34 @@
             }
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run transpose kernel
+                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
+                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
+            }
+            // In case of RESHAPED_ONLY_RHS_MMUL, we need to check the padding requirement
+            // Check if the lhs or dst tensors have padding
+            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
+            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
+            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
+
+            // Copy original tensor pack and overwrite rhs with reshaped counterpart
+            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
+            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
+
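+            // The MMUL kernel is configured with has_pad_y = false, so cross-plane (top/bottom)
+            // padding on the lhs/dst tensors cannot be handled and is treated as an error below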
+            if(has_pad_y)
+            {
+                ARM_COMPUTE_ERROR_ON(has_pad_y);
+            }
+            else
+            {
+                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
+            }
+            break;
+        }
         default:
         {
             ARM_COMPUTE_ERROR("GEMMType not supported");
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h
index 3c0cad3..aac463f 100644
--- a/src/gpu/cl/operators/ClGemm.h
+++ b/src/gpu/cl/operators/ClGemm.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -34,6 +34,7 @@
 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h"
 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h"
 #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h"
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
 #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h"
 #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
 
@@ -50,6 +51,7 @@
  *  -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method())
  *  -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method())
  *  -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method())
+ *  -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method())
  */
 class ClGemm : public IClOperator
 {
@@ -102,10 +104,12 @@
     void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
 
     static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
     static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
 
 private:
     enum AuxTensorIdx
@@ -116,17 +120,18 @@
     };
 
 private:
-    std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel>              _reshape_lhs_kernel;
-    std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel>              _reshape_rhs_kernel;
-    std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel>          _mm_native_kernel;
-    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel>        _mm_reshaped_kernel;
-    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel> _mm_reshaped_only_rhs_kernel;
-    TensorInfo                                                          _tmp_a;
-    TensorInfo                                                          _tmp_b;
-    bool                                                                _reshape_b_only_on_first_run;
-    CLGEMMKernelType                                                    _gemm_kernel_type;
-    bool                                                                _is_prepared;
-    experimental::MemoryRequirements                                    _aux_mem{};
+    std::unique_ptr<kernels::ClGemmReshapeLhsMatrixKernel>                  _reshape_lhs_kernel;
+    std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel>                  _reshape_rhs_kernel;
+    std::unique_ptr<kernels::ClGemmMatrixMultiplyNativeKernel>              _mm_native_kernel;
+    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedKernel>            _mm_reshaped_kernel;
+    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel>     _mm_reshaped_only_rhs_kernel;
+    std::unique_ptr<kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel;
+    TensorInfo                                                              _tmp_a;
+    TensorInfo                                                              _tmp_b;
+    bool                                                                    _reshape_b_only_on_first_run;
+    CLGEMMKernelType                                                        _gemm_kernel_type;
+    bool                                                                    _is_prepared;
+    experimental::MemoryRequirements                                        _aux_mem{};
 };
 } // namespace opencl
 } // namespace arm_compute
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index cc6689c..427ea51 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -30,7 +30,6 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
-#include "arm_compute/runtime/CL/functions/CLGEMM.h"
 #include "src/core/helpers/MemoryHelpers.h"
 #include "src/gpu/cl/operators/ClGemm.h"
 
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
index 64271a8..4c7daf9 100644
--- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
+++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -79,10 +79,28 @@
         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
     };
 
+    // Mali-G715 and Mali-G615 configurations
+    static std::map<DataType, FunctionExecutorPtr> gemm_g715_configs =
+    {
+        { DataType::F32, &CLGEMMDefaultTypeValhall::g715_f32 },
+        { DataType::F16, &CLGEMMDefaultTypeValhall::g715_f16 },
+        { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
+        { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
+        { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
+        { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
+    };
+
     const DataType data_type = params.data_type;
 
     switch(_target)
     {
+        case GPUTarget::G715:
+        case GPUTarget::G615:
+            if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end())
+            {
+                return (this->*gemm_g715_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
+            }
+            ARM_COMPUTE_ERROR("Not supported data type");
         case GPUTarget::G78:
             if(gemm_g78_configs.find(data_type) != gemm_g78_configs.end())
             {
@@ -306,5 +324,46 @@
 
     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
 }
+
+CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
+{
+    if(!is_rhs_constant)
+    {
+        return default_f32(m, n, k, b, is_rhs_constant);
+    }
+
+    unsigned int best_m0;
+    unsigned int best_n0;
+
+    if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
+    {
+        return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
+    }
+    else
+    {
+        return default_f32(m, n, k, b, is_rhs_constant);
+    }
+}
+
+CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
+{
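+    // Unlike the FP32 path, the FP16 fallback reuses the G78 heuristics when MMUL does not apply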
+    if(!is_rhs_constant)
+    {
+        return g78_f16(m, n, k, b, is_rhs_constant);
+    }
+
+    unsigned int best_m0;
+    unsigned int best_n0;
+
+    if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
+    {
+        return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
+    }
+    else
+    {
+        return g78_f16(m, n, k, b, is_rhs_constant);
+    }
+}
+
 } // namespace cl_gemm
 } // namespace arm_compute
diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
index c88fbcf..0893f11 100644
--- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
+++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -50,6 +50,8 @@
     CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
     CLGEMMKernelType g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
 };
 } // namespace cl_gemm
 } // namespace arm_compute
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp
new file mode 100644
index 0000000..7808be8
--- /dev/null
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp
@@ -0,0 +1,231 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h"
+#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/CL/Helper.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/Validation.h"
+#include "tests/validation/fixtures/GEMMFixture.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+using namespace arm_compute::opencl::kernels;
+
+// Create function for ClGemmReshapeRhsMatrixKernel
+using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator<ClGemmReshapeRhsMatrixKernel>;
+
+// Create function for ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel
+using CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL = CLSynthetizeOperator<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>;
+
+// Fixture for CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL
+template <typename T>
+using CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture = GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture<CLTensor, CLAccessor, T, CLGEMMReshapeRHSMatrix, CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL>;
+
+namespace
+{
+// *INDENT-OFF*
+// clang-format off
+RelativeTolerance<float> rel_tolerance_f32(0.001f);
+constexpr float          abs_tolerance_f32(0.0001f);
+RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.001f));
+constexpr float          abs_tolerance_f16(0.3f);
+
+/** Alpha values to test - Precommit */
+const auto a_values = framework::dataset::make("alpha", {1.0f, 0.75f} );
+
+/** Beta values to test - Precommit */
+const auto beta_values = framework::dataset::make("beta", {0.0f, -0.75f} );
+
+/** M values to test */
+const auto m_values = framework::dataset::make("M", {49});
+
+/** N values to test */
+const auto n_values = framework::dataset::make("N", {257});
+
+/** K values to test - must be a multiple of 4 for the MMUL kernel */
+const auto k_values = framework::dataset::make("K", {192});
+
+/** Batch size values to test */
+const auto b_values = framework::dataset::make("batch_size", {1, 2});
+
+/** Activation values to test */
+const auto act_values = framework::dataset::make("Activation",
+{
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+});
+
+/** M0 values to test - Precommit */
+const auto m0_values_precommit = framework::dataset::make("M0", { 1, 2, 4 });
+
+/** N0 values to test - Precommit */
+const auto n0_values_precommit = framework::dataset::make("N0", { 4, 8 });
+
+/** K0 values to test - Precommit */
+const auto k0_values_precommit = framework::dataset::make("K0", { 1 });
+
+/** Broadcast bias from vector to matrix */
+const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } );
+
+} // namespace
+
+TEST_SUITE(CL)
+TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRhsMMUL)
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", false)),
+                                                                   framework::dataset::make("DataType", DataType::F32)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // Validate output
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", false)),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // Validate output
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+TEST_SUITE_END() // FP16
+
+TEST_SUITE(ExportToCLImage)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<float>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", true)),
+                                                                   framework::dataset::make("DataType", DataType::F32)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // Validate output
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+
+TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   framework::dataset::make("ExportToCLImage", true)),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   act_values))
+{
+    // Validate output
+    if(validate_result)
+    {
+        validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+    }
+    else
+    {
+        ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped");
+        framework::ARM_COMPUTE_PRINT_INFO();
+    }
+}
+TEST_SUITE_END() // FP16
+TEST_SUITE_END() // ExportToCLImage
+TEST_SUITE_END() // Float
+TEST_SUITE_END() // GEMMMatrixMultiplyReshapedOnlyRhsMMUL
+TEST_SUITE_END() // CL
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h
index 884b13d..55bbbda 100644
--- a/tests/validation/fixtures/GEMMFixture.h
+++ b/tests/validation/fixtures/GEMMFixture.h
@@ -163,18 +163,18 @@
             const int m          = reinterpret_output_as_3d ? output_shape[1] * output_shape[2] : output_shape[1];
             const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2];
 
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(c.data() + i * n, c.data(), n * sizeof(T));
             }
         }
-        
+
         /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M),
            therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K)
            in order to be able to call reference implementation that works with (B x M x K) input.
            Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */
-           
+
         // Define transposed shapes
         TensorShape a_transposed_shape(a.shape().y(), a.shape().x());
         TensorShape b_transposed_shape(b.shape().y(), b.shape().x());
@@ -315,7 +315,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -438,7 +438,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -593,7 +593,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -748,7 +748,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -923,7 +923,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1169,7 +1169,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1361,7 +1361,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1533,7 +1533,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1759,7 +1759,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -1941,7 +1941,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2078,7 +2078,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2274,7 +2274,7 @@
 
         if(broadcast_bias)
         {
-            // In case of broadcast, we need simply copy the first into the following "M" ones
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
             for(int i = 1; i < m * batch_size; i++)
             {
                 memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2421,7 +2421,7 @@
         fill(rhs, 1);
         fill(bias, 2);
 
-        // In case of broadcast, we need simply copy the first into the following "M" ones
+        // In case of broadcast, we need to simply copy the first into the following "M" ones
         for(int i = 1; i < m * batch_size; i++)
         {
             memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
@@ -2434,6 +2434,171 @@
     SimpleTensor<T> _reference{};
 };
 
+template <typename TensorType, typename AccessorType, typename T, typename ReshapeRHSOperatorType, typename GEMMOperatorType>
+class GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, bool export_to_cl_image, DataType data_type, float alpha,
+               float beta, bool broadcast_bias,
+               const ActivationLayerInfo &act_info)
+    {
+        GEMMLHSMatrixInfo lhs_info;
+        lhs_info.m0 = m0;
+        lhs_info.k0 = k0;
+
+        GEMMRHSMatrixInfo rhs_info;
+        rhs_info.n0                 = n0;
+        rhs_info.k0                 = k0;
+        rhs_info.interleave         = true;
+        rhs_info.transpose          = false;
+        rhs_info.h0                 = 4;
+        rhs_info.export_to_cl_image = export_to_cl_image;
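+        // Note: h0 = 4 and transpose = false mirror the values ClGemm forces on the MMUL path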
+
+        // Set the tensor shapes for LHS and RHS matrices
+        const TensorShape lhs_shape(k, m, batch_size);
+        const TensorShape rhs_shape(n, k, batch_size);
+        const TensorShape bias_shape(n,
+                                     broadcast_bias ? 1 : m,
+                                     broadcast_bias ? 1 : batch_size);
+
+        _target    = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info);
+        _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info);
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i)
+    {
+        static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
+        using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
+
+        DistributionType distribution{ T(-1.0f), T(1.0f) };
+        library->fill(tensor, distribution, i);
+
+        // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0)
+        DistributionType distribution_inf{ T(std::numeric_limits<float>::infinity()), T(std::numeric_limits<float>::infinity()) };
+        library->fill_borders_with_garbage(tensor, distribution_inf, i);
+    }
+
+    TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info,
+                              DataType data_type, float alpha, float beta, bool broadcast_bias, const ActivationLayerInfo &act_info)
+    {
+        // Create tensors
+        TensorType lhs  = create_tensor<TensorType>(lhs_shape, data_type, 1);
+        TensorType rhs  = create_tensor<TensorType>(rhs_shape, data_type, 1);
+        TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
+        TensorType rhs_reshaped;
+        TensorType dst;
+
+        const unsigned int M = lhs_shape[1];
+        const unsigned int N = rhs_shape[0];
+        const unsigned int K = lhs_shape[0];
+        GEMMKernelInfo     kernel_info;
+        kernel_info.m                       = M;
+        kernel_info.n                       = N;
+        kernel_info.k                       = K;
+        kernel_info.depth_output_gemm3d     = 0;
+        kernel_info.reinterpret_input_as_3d = false;
+        kernel_info.broadcast_bias          = broadcast_bias;
+        kernel_info.activation_info         = act_info;
+
+        // Create and configure function
+        ReshapeRHSOperatorType reshape_rhs;
+        GEMMOperatorType       gemm;
+
+        validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info));
+        if(!validate_result)
+        {
+            return nullptr;
+        }
+
+        reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info);
+
+        validate_result = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info));
+        if(!validate_result)
+        {
+            return nullptr;
+        }
+
+        gemm.configure(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info);
+
+        ARM_COMPUTE_ASSERT(lhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(rhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
+
+        // Allocate tensors
+        lhs.allocator()->allocate();
+        rhs.allocator()->allocate();
+        rhs_reshaped.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
+        ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
+
+        // Fill tensors
+        fill(AccessorType(lhs), 0);
+        fill(AccessorType(rhs), 1);
+        fill(AccessorType(bias), 2);
+
+        // Compute GEMM
+        ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } };
+        reshape_rhs.run(reshape_rhs_pack);
+        ITensorPack gemm_pack({ { ACL_SRC_0, &lhs },
+            { ACL_SRC_1, &rhs_reshaped },
+            { ACL_SRC_2, &bias },
+            { ACL_DST, &dst }
+        });
+        gemm.run(gemm_pack);
+
+        return dst;
+    }
+
+    SimpleTensor<T> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, float beta, bool broadcast_bias,
+                                      const ActivationLayerInfo &act_info)
+    {
+        if(!validate_result)
+        {
+            return SimpleTensor<T>();
+        }
+
+        TensorShape dst_shape = lhs_shape;
+        dst_shape[0]          = rhs_shape[0];
+        dst_shape[1]          = lhs_shape[1];
+
+        // Create reference
+        SimpleTensor<T> lhs{ lhs_shape, data_type, 1 };
+        SimpleTensor<T> rhs{ rhs_shape, data_type, 1 };
+        SimpleTensor<T> bias{ dst_shape, data_type, 1 };
+
+        const int n          = rhs_shape[0];
+        const int m          = lhs_shape[1];
+        const int batch_size = lhs_shape[2];
+
+        // Fill reference
+        fill(lhs, 0);
+        fill(rhs, 1);
+        fill(bias, 2);
+
+        if(broadcast_bias)
+        {
+            // In case of broadcast, we need to simply copy the first into the following "M" ones
+            for(int i = 1; i < m * batch_size; i++)
+            {
+                memcpy(bias.data() + i * n, bias.data(), n * sizeof(T));
+            }
+        }
+
+        return reference::activation_layer(reference::gemm<T>(lhs, rhs, bias, alpha, beta), act_info);
+    }
+
+    bool            validate_result = true;
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+};
+
 } // namespace validation
 } // namespace test
 } // namespace arm_compute
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index dae81e4..31eff57 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -2358,6 +2358,12 @@
         case GPUTarget::G710:
             os << "G710";
             break;
+        case GPUTarget::G715:
+            os << "G715";
+            break;
+        case GPUTarget::G615:
+            os << "G615";
+            break;
         default:
             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
     }