COMPMID-2097: Implement a heuristic to dispatch CLGEMMReshapedOnlyRHS kernel from CLGEMM

Change-Id: I4170a80647b02501aa669e2c0347ddc39888ee76
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/928
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 2ac6f81..60bfbf2 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -23,7 +23,10 @@
  */
 #include "arm_compute/runtime/CL/functions/CLGEMM.h"
 
+#include "arm_compute/core/CL/ICLGEMMKernelConfiguration.h"
 #include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
+#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/GPUTarget.h"
 #include "arm_compute/core/Helpers.h"
@@ -33,7 +36,6 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "arm_compute/runtime/CL/gemm_reshaped/CLGEMMReshapedConfiguration.h"
 #include "arm_compute/runtime/ITensorAllocator.h"
 
 namespace arm_compute
@@ -41,46 +43,6 @@
 using namespace arm_compute::misc::shape_calculator;
 using namespace arm_compute::cl_gemm;
 
-namespace
-{
-inline bool is_interleaved_transposed(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
-{
-    bool flag = true;
-
-    if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
-    {
-        if((m > 1) && n < 16)
-        {
-            flag = true;
-        }
-        else
-        {
-            // COMPMID-852
-            if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
-            {
-                constexpr float alpha = 3.2f;
-                constexpr float fact0 = 1.51f;
-                constexpr float fact1 = 1.66f;
-                constexpr float ops   = 12.0f;
-                const float     scale = k > 1024 ? 1.07f : 1.0f;
-                flag                  = alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops);
-            }
-            else
-            {
-                flag = false;
-            }
-        }
-    }
-    else
-    {
-        // We reshape the matrices only if we do not have the vector-by-matrix case and we reshape the matrix B only once
-        flag = m != 1 && reshape_b_only_on_first_run;
-    }
-
-    return flag;
-}
-} // namespace
-
 CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)),
       _mm_kernel(),
@@ -88,57 +50,102 @@
       _reshape_lhs_kernel(),
       _reshape_rhs_kernel(),
       _mm_reshaped_kernel(),
+      _mm_reshaped_only_rhs_kernel(),
       _tmp_a(),
       _tmp_b(),
       _original_b(nullptr),
-      _is_interleaved_transposed(false),
       _run_addition(false),
       _reshape_b_only_on_first_run(false),
       _is_prepared(false),
-      _is_new_gemm_reshaped(false)
+      _gemm_type(GEMMType::NATIVE)
 {
 }
 
-void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+CLGEMM::GEMMType CLGEMM::select_gemm_type(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
+    GEMMType gemm_type = GEMMType::RESHAPED_V1; // GEMM variant chosen from the shape, data type and GPU target; every branch below overwrites this initial value
 
-    // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
+    if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+    {
+        if((m > 1) && (n < 16))
+        {
+            gemm_type = GEMMType::RESHAPED_V1;
+        }
+        else if((m == 1) && (data_type == DataType::F32))
+        {
+            gemm_type = GEMMType::RESHAPED_ONLY_RHS;
+        }
+        else
+        {
+            // COMPMID-852: with k > 256, m > 4, a float data type and B reshaped only once, compare the two expressions below; every other case is NATIVE
+            if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+            {
+                constexpr float alpha = 3.2f;
+                constexpr float fact0 = 1.51f;
+                constexpr float fact1 = 1.66f;
+                constexpr float ops   = 12.0f;
+                const float     scale = k > 1024 ? 1.07f : 1.0f;
+                gemm_type             = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+            }
+            else
+            {
+                gemm_type = GEMMType::NATIVE;
+            }
+        }
 
-    // Check if we need to reshape the matrix B only on the first run
-    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
-    _is_prepared                 = gemm_info.retain_internal_weights();
-    _original_b                  = b;
+        const auto workload = static_cast<float>((m * n) / 20.0f); // m * n is multiplied as unsigned int before the float division; a RESHAPED_V1 choice is upgraded to RESHAPED_V2 below for F32 when this exceeds 1600
 
-    const ICLTensor *matrix_a = a;
-    const ICLTensor *matrix_b = b;
+        gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
+    }
+    else
+    {
+        // Targets not listed above: reshape both matrices (RESHAPED_V1) only if this is not the vector-by-matrix case (m == 1) and matrix B is reshaped only once; otherwise NATIVE
+        gemm_type = ((m != 1) && reshape_b_only_on_first_run) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+    }
 
-    // Get the GPU target
-    const GPUTarget gpu_target = CLScheduler::get().target();
+    return gemm_type;
+}
+
+void CLGEMM::configure_native(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    const unsigned int m          = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n          = b->info()->dimension(0);
+    const unsigned int k          = a->info()->dimension(0);
+    const GPUTarget    gpu_target = CLScheduler::get().target();
 
     // Set the target for the kernels
-    _reshape_lhs_kernel.set_target(gpu_target);
     _mm_kernel.set_target(gpu_target);
 
-    // Arguments used by GEMMReshapeInfo
-    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
-    // in order to know how the matrices have been reshaped
-    DataType           data_type                 = a->info()->data_type();
+    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d());
+
+    // NATIVE path: configure the matrix multiply kernel directly on a and b (no reshape kernels); the 'false' argument marks the inputs as not interleaved/transposed
+    _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision());
+
+    // Let the scheduler apply its static tuning to the kernel
+    CLScheduler::get().tune_kernel_static(_mm_kernel);
+}
+
+void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
     bool               reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
     const unsigned int m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
     const unsigned int n                         = b->info()->dimension(0);
     const unsigned int k                         = a->info()->dimension(0);
-    const unsigned int batch_size                = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
     const int          depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    const GPUTarget    gpu_target                = CLScheduler::get().target();
     int                mult_transpose1xW_width   = 1;
     int                mult_interleave4x4_height = 1;
 
+    // Set the target for the kernels
+    _reshape_lhs_kernel.set_target(gpu_target);
+    _mm_kernel.set_target(gpu_target);
+
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
         mult_transpose1xW_width   = 4;
         mult_interleave4x4_height = 2;
     }
+
     GEMMRHSMatrixInfo rhs_info;
     rhs_info.n0         = 16 / b->info()->element_size();
     rhs_info.k0         = 1;
@@ -153,112 +160,183 @@
     lhs_info.interleave = true;
     lhs_info.transpose  = true;
 
-    // Check if we need to reshape the matrix A and matrix B
-    _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
+    GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false);
 
-    // Check if we can run the new reshaped GEMM
-    const auto workload   = static_cast<float>((m * n) / 20.0f);
-    _is_new_gemm_reshaped = (workload > 1600.0f) && (get_arch_from_target(gpu_target) == GPUTarget::BIFROST) && _is_interleaved_transposed && (data_type == DataType::F32);
-
-    const bool add_matrix_c  = (beta != 0.f && c != nullptr);
-    const bool is_beta_one   = std::abs(1.0f - beta) < 0.00001f;
-    const bool use_fused_add = is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1) && !_is_new_gemm_reshaped;
-
-    // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
-    if(_is_interleaved_transposed)
+    _memory_group.manage(&_tmp_a);
+    if(!_reshape_b_only_on_first_run)
     {
-        reinterpret_input_as_3d = false;
-
-        matrix_a = &_tmp_a;
-        matrix_b = &_tmp_b;
-
-        // Manage intermediate buffers
-        _memory_group.manage(&_tmp_a);
-        if(!_reshape_b_only_on_first_run)
-        {
-            _memory_group.manage(&_tmp_b);
-        }
-        // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
-
-        if(_is_new_gemm_reshaped)
-        {
-            GEMMLHSMatrixInfo lhs_info;
-
-            // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, data_type);
-
-            _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
-            _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
-
-            // Configure and tune matrix multiply kernel
-            _mm_reshaped_kernel.configure(matrix_a, matrix_b, output, alpha, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1,
-                                                                                                                 depth_output_gemm3d, reinterpret_input_as_3d));
-        }
-        else
-        {
-            // Configure interleave kernel
-            _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
-            // Configure transpose kernel
-            _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
-        }
+        _memory_group.manage(&_tmp_b);
     }
 
-    if(!_is_new_gemm_reshaped)
-    {
-        // Configure and tune matrix multiply kernel
-        _mm_kernel.configure(matrix_a, matrix_b, (add_matrix_c && !use_fused_add) ? nullptr : c, output, alpha, beta, _is_interleaved_transposed,
-                             GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d),
-                             gemm_info.fp_mixed_precision());
-        CLScheduler::get().tune_kernel_static(_mm_kernel);
-    }
+    // Configure interleave kernel
+    _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
 
-    if(_is_interleaved_transposed)
-    {
-        // Allocate intermediate tensors
-        _tmp_a.allocator()->allocate();
-        if(!_reshape_b_only_on_first_run)
-        {
-            _tmp_b.allocator()->allocate();
-        }
-    }
+    // Configure transpose kernel
+    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
 
-    // Configure matrix addition kernel
-    if(add_matrix_c && !use_fused_add)
+    // Configure and tune matrix multiply kernel
+    _mm_kernel.configure(&_tmp_a, &_tmp_b, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision());
+
+    CLScheduler::get().tune_kernel_static(_mm_kernel);
+
+    // Allocate intermediate tensors
+    _tmp_a.allocator()->allocate();
+    if(!_reshape_b_only_on_first_run)
     {
-        _ma_kernel.configure(c, output, beta);
-        _run_addition = true;
+        _tmp_b.allocator()->allocate();
     }
 }
 
-Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+void CLGEMM::configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    ARM_COMPUTE_ERROR_ON(c != nullptr);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(c);
+    // RESHAPED_V2 path: reshape a into _tmp_a and b into _tmp_b, then multiply them with _mm_reshaped_kernel. c must be null and beta is ignored.
+    DataType           data_type               = a->info()->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                       = b->info()->dimension(0);
+    const unsigned int k                       = a->info()->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+
+    // Set the target for the kernels. NOTE(review): _mm_kernel is not configured on this path (_mm_reshaped_kernel is) - confirm which kernel should receive the target
+    _reshape_lhs_kernel.set_target(gpu_target);
+    _mm_kernel.set_target(gpu_target);
+    // The multiply kernel reads the already reshaped _tmp_a, so its input must not be reinterpreted as 3D (same setting as validate_reshaped_v2)
+    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, false);
+
+    // Manage intermediate buffers
+    _memory_group.manage(&_tmp_a);
+    if(!_reshape_b_only_on_first_run)
+    {
+        _memory_group.manage(&_tmp_b);
+    }
+    // _tmp_a and _tmp_b are passed as outputs to _reshape_lhs_kernel and _reshape_rhs_kernel below, which are expected to auto configure them
+
+    GEMMLHSMatrixInfo lhs_info{};
+    GEMMRHSMatrixInfo rhs_info{};
+
+    // Pick up the GEMM configuration
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+
+    // Configure lhs_info and rhs_info
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+
+    _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
+    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+    // Configure matrix multiply kernel on the reshaped operands
+    _mm_reshaped_kernel.configure(&_tmp_a, &_tmp_b, output, alpha, lhs_info, rhs_info, reshape_info);
+
+    // Allocate intermediate tensors
+    _tmp_a.allocator()->allocate();
+    if(!_reshape_b_only_on_first_run)
+    {
+        _tmp_b.allocator()->allocate();
+    }
+}
+
+void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    ARM_COMPUTE_ERROR_ON(c != nullptr);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(c);
+    // RESHAPED_ONLY_RHS path: only b is reshaped (into _tmp_b); a is fed to _mm_reshaped_only_rhs_kernel as-is. c must be null and beta is ignored.
+    DataType           data_type               = a->info()->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                       = b->info()->dimension(0);
+    const unsigned int k                       = a->info()->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+
+    // Set the target for the kernels. NOTE(review): _mm_kernel is not configured on this path (_mm_reshaped_only_rhs_kernel is) - confirm which kernel should receive the target
+    _mm_kernel.set_target(gpu_target);
+
+    GEMMReshapeInfo reshape_info(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+
+    // Manage intermediate buffers
+    if(!_reshape_b_only_on_first_run)
+    {
+        _memory_group.manage(&_tmp_b);
+    }
+
+    GEMMLHSMatrixInfo lhs_info{};
+    GEMMRHSMatrixInfo rhs_info{};
+
+    // Pick up the GEMM configuration
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+
+    // Configure lhs_info and rhs_info
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+
+    _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
+
+    // Configure matrix multiply kernel: the LHS is the original a, the RHS is the reshaped _tmp_b
+    _mm_reshaped_only_rhs_kernel.configure(a, &_tmp_b, output, alpha, lhs_info, rhs_info, reshape_info);
+    // _tmp_b is allocated here only when b is not reshaped just on the first run
+    if(!_reshape_b_only_on_first_run)
+    {
+        _tmp_b.allocator()->allocate();
+    }
+}
+
+Status CLGEMM::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(output);
 
-    // Check if we need to reshape the matrix B only on the first run
-    const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+    // NATIVE path validation: gather the GPU target, the GEMM dimensions and the flags describing how c and beta are handled
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool         add_c                   = (beta != 0.f && c != nullptr);
+    const bool         is_beta_one             = std::abs(1.0f - beta) < 0.00001f;
+    const bool         fuse_add                = is_beta_one && (c != nullptr && c->num_dimensions() == 1);
 
-    const ITensorInfo *matrix_a_info = a;
-    const ITensorInfo *matrix_b_info = b;
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+
+    // Validate matrix multiply on the original a and b; c is passed to it only when the addition is fused (beta within 0.00001 of 1 and c has one dimension)
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta,
+                                                                     false, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
+
+    if(add_c && !fuse_add)
+    {
+        // Validate the separate matrix addition kernel, needed when c is present but cannot be fused
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
+    }
+
+    return Status{};
+}
+
+Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(output);
 
     TensorInfo tmp_a_info{};
     TensorInfo tmp_b_info{};
 
     // Get the GPU target
-    const GPUTarget gpu_target = CLScheduler::get().target();
-
-    // Arguments used by GEMMReshapeInfo
-    // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
-    // in order to know how the matrices have been reshaped
-    DataType           data_type                 = a->data_type();
-    bool               reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    const unsigned int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const GPUTarget    gpu_target                = CLScheduler::get().target();
+    const unsigned int m                         = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const unsigned int n                         = b->dimension(0);
     const unsigned int k                         = a->dimension(0);
-    const unsigned int batch_size                = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
     int                mult_transpose1xW_width   = 1;
     int                mult_interleave4x4_height = 1;
     const int          depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    const bool         add_c                     = (beta != 0.f && c != nullptr);
+    const bool         is_beta_one               = std::abs(1.0f - beta) < 0.00001f;
+    const bool         fuse_add                  = is_beta_one && (c != nullptr && c->num_dimensions() == 1);
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -280,66 +358,21 @@
     lhs_info.interleave = true;
     lhs_info.transpose  = true;
 
-    // Check if we need to reshape the matrix A and matrix B
-    const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false);
 
-    // Check if we can run the new reshaped GEMM
-    const auto workload             = static_cast<float>((m * n) / 20.0f);
-    const bool is_new_gemm_reshaped = (workload > 1600.f) && (get_arch_from_target(gpu_target) == GPUTarget::BIFROST) && run_interleave_transpose && (data_type == DataType::F32);
+    // Validate interleave kernel
+    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
 
-    const bool add_matrix_c  = (beta != 0.f && c != nullptr);
-    const bool is_beta_one   = std::abs(1.0f - beta) < 0.00001f;
-    const bool use_fused_add = is_beta_one && (c != nullptr && c->num_dimensions() == 1) && !is_new_gemm_reshaped;
+    // Validate transpose kernel
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
 
-    // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
-    if(run_interleave_transpose)
-    {
-        reinterpret_input_as_3d = false;
-    }
+    // Validate matrix multiply
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, (add_c && fuse_add) ? c : nullptr, output, alpha, beta,
+                                                                     true, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
-
-    if(run_interleave_transpose)
-    {
-        matrix_a_info = &tmp_a_info;
-        matrix_b_info = &tmp_b_info;
-
-        if(is_new_gemm_reshaped)
-        {
-            GEMMLHSMatrixInfo lhs_info;
-
-            // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, data_type);
-
-            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
-
-            auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
-
-            // Validate matrix multiply
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(matrix_a_info, matrix_b_info, output, alpha, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1,
-                                                                                     depth_output_gemm3d, reinterpret_input_as_3d)));
-        }
-        else
-        {
-            // Validate interleave kernel
-            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
-            // Validate transpose kernel
-            auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
-            ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
-        }
-    }
-
-    if(!is_new_gemm_reshaped)
-    {
-        // Validate matrix multiply
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, (add_matrix_c && !use_fused_add) ? nullptr : c, output, alpha, beta,
-                                                                         run_interleave_transpose, reshape_info, gpu_target, gemm_info.fp_mixed_precision()));
-    }
-
-    if(add_matrix_c && !use_fused_add)
+    if(add_c && !fuse_add)
     {
         // Validate matrix addition kernel
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
@@ -348,32 +381,263 @@
     return Status{};
 }
 
+Status CLGEMM::validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{ // Validates the reshaped (V2) GEMM path: LHS reshape, RHS reshape, reshaped matrix multiply and, when C is added, matrix addition
+    ARM_COMPUTE_UNUSED(alpha);  // NOTE(review): alpha is forwarded to the matrix multiply validation below, so this marker looks redundant
+    ARM_COMPUTE_UNUSED(output); // NOTE(review): output is also used below (multiply and addition validation) - confirm whether this is still needed
+
+    TensorInfo tmp_a_info{}; // Holds the reshaped-LHS tensor info, initialised further down
+    TensorInfo tmp_b_info{}; // Holds the reshaped-RHS tensor info, initialised further down
+
+    // Get the GPU target and derive the GEMM dimensions from the shapes of a and b
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    DataType           data_type               = a->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); // dimensions 1 and 2 of a are collapsed into M when the input is reinterpreted as 3D
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool         add_c                   = (beta != 0.f && c != nullptr); // C is only added when beta is non-zero and c is provided
+
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, false); // NOTE(review): last argument is fixed to false here, whereas validate_reshaped_only_rhs forwards reinterpret_input_as_3d - presumably because the LHS reshape kernel receives that flag instead; confirm
+
+    GEMMLHSMatrixInfo lhs_info;
+    GEMMRHSMatrixInfo rhs_info;
+
+    // Pick up the GEMM configuration for the current GPU target; an error is returned if the factory yields a null configuration
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
+
+    // Configure lhs_info and rhs_info from the problem size (m, n, k, batch_size) and the data type
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+
+    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
+
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+    // Validate matrix multiply on the reshaped LHS/RHS tensor infos
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, output, alpha, lhs_info, rhs_info, reshape_info));
+
+    if(add_c)
+    {
+        // Validate matrix addition kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
+    }
+
+    return Status{};
+}
+
+Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{ // Validates the GEMM path where only the RHS is reshaped: RHS reshape, matrix multiply and, when C is added, matrix addition
+    ARM_COMPUTE_UNUSED(alpha);  // NOTE(review): alpha is forwarded to the matrix multiply validation below, so this marker looks redundant
+    ARM_COMPUTE_UNUSED(output); // NOTE(review): output is also used below (multiply and addition validation) - confirm whether this is still needed
+
+    TensorInfo tmp_b_info{}; // Holds the reshaped-RHS tensor info, initialised further down
+
+    // Get the GPU target and derive the GEMM dimensions from the shapes of a and b
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    const DataType     data_type               = a->data_type();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); // dimensions 1 and 2 of a are collapsed into M when the input is reinterpreted as 3D
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool         add_c                   = (beta != 0.f && c != nullptr); // C is only added when beta is non-zero and c is provided
+
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); // reinterpret_input_as_3d is forwarded to the multiply kernel through reshape_info
+
+    GEMMLHSMatrixInfo lhs_info; // still filled in by the configuration and passed to the multiply kernel, although no LHS reshape is validated here
+    GEMMRHSMatrixInfo rhs_info;
+
+    // Pick up the GEMM configuration for the current GPU target; an error is returned if the factory yields a null configuration
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
+
+    // Configure lhs_info and rhs_info from the problem size (m, n, k, batch_size) and the data type
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+    // Validate matrix multiply (a is passed as-is; only b goes through the reshaped tensor info)
+    ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, output, alpha, lhs_info, rhs_info, reshape_info));
+
+    if(add_c)
+    {
+        // Validate matrix addition kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixAdditionKernel::validate(c, output, beta));
+    }
+
+    return Status{};
+}
+
+void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{ // Configures the function: selects a GEMM variant, delegates to the matching configure_* method and sets up a separate addition kernel when needed
+    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
+
+    // Perform validation step (c is optional, so its info is only forwarded when c is present)
+    ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
+
+    // Check if we need to reshape the matrix B only on the first run
+    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+    _is_prepared                 = gemm_info.retain_internal_weights(); // when internal weights are retained the function starts out flagged as prepared
+    _original_b                  = b;
+
+    // Get the GPU target and derive the GEMM dimensions from the shapes of a and b
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
+    const unsigned int n                       = b->info()->dimension(0);
+    const unsigned int k                       = a->info()->dimension(0);
+
+    // Select GEMMType; the choice is stored in _gemm_type because run() and prepare() read it later
+    _gemm_type = select_gemm_type(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
+
+    const bool is_gemm_v2  = (_gemm_type == GEMMType::RESHAPED_V2) || (_gemm_type == GEMMType::RESHAPED_ONLY_RHS);
+    const bool add_c       = (beta != 0.f && c != nullptr); // C contributes only if beta is non-zero and c is provided
+    const bool is_beta_one = std::abs(1.0f - beta) < 0.00001f; // beta is treated as one within a 1e-5 tolerance
+    const bool fuse_add    = is_beta_one && (c != nullptr && c->info()->num_dimensions() == 1) && !is_gemm_v2; // fuse only for a 1D c with beta ~= 1, and never for RESHAPED_V2 / RESHAPED_ONLY_RHS
+
+    switch(_gemm_type)
+    {
+        case GEMMType::NATIVE:
+        {
+            configure_native(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, gemm_info); // c is handed over only when the addition is fused
+            break;
+        }
+        case GEMMType::RESHAPED_V1:
+        {
+            configure_reshaped_v1(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, gemm_info);
+            break;
+        }
+        case GEMMType::RESHAPED_V2:
+        {
+            configure_reshaped_v2(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, gemm_info); // fuse_add is always false for this type, so nullptr is passed
+            break;
+        }
+        case GEMMType::RESHAPED_ONLY_RHS:
+        {
+            configure_reshaped_only_rhs(a, b, (add_c && fuse_add) ? c : nullptr, output, alpha, beta, gemm_info); // fuse_add is always false for this type, so nullptr is passed
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("GEMMType not supported");
+        }
+    }
+
+    // Configure matrix addition kernel (only when C is added and the addition was not fused above)
+    if(add_c && !fuse_add)
+    {
+        _ma_kernel.configure(c, output, beta);
+        _run_addition = true;
+    }
+}
+
+Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{ // Validates the arguments by selecting the GEMM variant and delegating to the matching validate_* method; returns an empty Status when the delegate reports no error
+    // Get the GPU target and the GEMM dimensions (NOTE(review): a and b are dereferenced below without a null check in this function - confirm callers guarantee non-null)
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+
+    // Select GEMMType (same select_gemm_type call, with the same kinds of arguments, as in configure())
+    GEMMType gemm_type = select_gemm_type(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run(), gpu_target);
+
+    switch(gemm_type)
+    {
+        case GEMMType::NATIVE:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c, output, alpha, beta, gemm_info));
+            break;
+        }
+        case GEMMType::RESHAPED_V1:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c, output, alpha, beta, gemm_info));
+            break;
+        }
+        case GEMMType::RESHAPED_V2:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v2(a, b, c, output, alpha, beta, gemm_info));
+            break;
+        }
+        case GEMMType::RESHAPED_ONLY_RHS:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c, output, alpha, beta, gemm_info));
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
+        }
+    }
+
+    return Status{};
+}
+
 void CLGEMM::run()
 {
     prepare();
 
     MemoryGroupResourceScope scope_mg(_memory_group);
 
-    if(_is_interleaved_transposed)
-    {
-        // Run interleave kernel
-        CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
-
-        if(!_reshape_b_only_on_first_run)
-        {
-            // Run transpose kernel
-            CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
-        }
-    }
-
     // Run matrix multiply kernel
-    if(_is_new_gemm_reshaped)
+    switch(_gemm_type)
     {
-        CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition);
-    }
-    else
-    {
-        CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+        case GEMMType::NATIVE:
+        {
+            CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+            break;
+        }
+        case GEMMType::RESHAPED_V1:
+        {
+            // Run interleave kernel
+            CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
+
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run transpose kernel
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+            }
+
+            CLScheduler::get().enqueue(_mm_kernel, !_run_addition);
+            break;
+        }
+        case GEMMType::RESHAPED_V2:
+        {
+            // Run interleave kernel
+            CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
+
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run transpose kernel
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+            }
+
+            CLScheduler::get().enqueue(_mm_reshaped_kernel, !_run_addition);
+            break;
+        }
+        case GEMMType::RESHAPED_ONLY_RHS:
+        {
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run transpose kernel
+                CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
+            }
+
+            CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, !_run_addition);
+            break;
+        }
+        default:
+        {
+            ARM_COMPUTE_ERROR("GEMMType not supported");
+        }
     }
 
     // Run matrix addition kernel
@@ -387,7 +651,7 @@
 {
     if(!_is_prepared)
     {
-        if(_is_interleaved_transposed && _reshape_b_only_on_first_run)
+        if(_gemm_type != GEMMType::NATIVE && _reshape_b_only_on_first_run)
         {
             // Run transpose kernel and mark original weights tensor as unused
             _tmp_b.allocator()->allocate();
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index c0bd85d..c447cb8 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -24,6 +24,7 @@
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
 
 #include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/TensorInfo.h"
@@ -31,7 +32,6 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "arm_compute/runtime/CL/gemm_reshaped/CLGEMMReshapedConfiguration.h"
 
 namespace arm_compute
 {
@@ -122,12 +122,12 @@
         }
 
         // Pick up the GEMM configuration
-        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
 
-        // Configure interleave kernel
+        // Configure reshape LHS kernel
         _mtx_a_reshape_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
 
-        // Configure transpose kernel
+        // Configure reshape RHS kernel
         _mtx_b_reshape_kernel.configure(b, &_tmp_b, rhs_info);
     }
 
@@ -236,6 +236,9 @@
     GEMMRHSMatrixInfo rhs_info;
     GEMMLHSMatrixInfo lhs_info;
 
+    // Get the GPU target
+    const GPUTarget gpu_target = CLScheduler::get().target();
+
     bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
     const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const unsigned int n                       = b->dimension(0);
@@ -259,14 +262,13 @@
         matrix_b_info = &tmp_b_info;
 
         // Pick up the GEMM configuration
-        std::tie(lhs_info, rhs_info) = CLGEMMReshapedConfigurationFactory::create()->configure(m, n, k, batch_size, DataType::QASYMM8);
+        std::tie(lhs_info, rhs_info) = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
 
-        // Validate interleave kernel
+        // Validate reshape LHS kernel
         auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
 
-        // Validate transpose kernel
-
+        // Validate reshape RHS kernel
         auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
     }