Add Gemm MMUL Reshaped Only Rhs Support for FP32/FP16

This patch introduces a GEMM routine that is optimized for Arm(R) Mali(TM)-G715 and Arm(R) Mali(TM)-G615

Resolves: COMPMID-5216
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Change-Id: I2e5d7806f5904347185bb3e250f73d73d6669dba
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7914
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index 88f6b79..4db39a6 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -191,6 +191,7 @@
       _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
       _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
       _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
+      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
       _tmp_a(),
       _tmp_b(),
       _reshape_b_only_on_first_run(false),
@@ -324,6 +325,53 @@
     _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
 }
 
+void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
+                                              const GEMMInfo &gemm_info)
+{
+    // Sets up the RHS reshape kernel and the MMUL matrix multiply kernel, then registers the
+    // auxiliary buffer that holds the reshaped RHS matrix.
+    const GPUTarget target      = CLScheduler::get().target();
+    const DataType  dt          = a->data_type();
+    const bool      input_as_3d = gemm_info.reinterpret_input_as_3d();
+
+    // Problem size: with a 3D-reinterpreted input, M spans dimensions 1 and 2 of the LHS and the batch is read from dimension 3
+    const unsigned int rows_m  = input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int cols_n  = b->dimension(0);
+    const unsigned int depth_k = a->dimension(0);
+    const unsigned int batches = input_as_3d ? a->dimension(3) : a->dimension(2);
+
+    // Descriptor handed to the matrix multiply kernel; it is configured with no y padding support
+    GEMMKernelInfo mm_desc;
+    mm_desc.m                       = rows_m;
+    mm_desc.n                       = cols_n;
+    mm_desc.k                       = depth_k;
+    mm_desc.depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    mm_desc.reinterpret_input_as_3d = input_as_3d;
+    mm_desc.broadcast_bias          = gemm_info.broadcast_bias();
+    mm_desc.activation_info         = gemm_info.activation_info();
+    mm_desc.post_ops                = gemm_info.post_ops();
+    mm_desc.has_pad_y               = false;
+
+    // Propagate the GPU target to the matrix multiply kernel
+    _mm_reshaped_only_rhs_mmul_kernel->set_target(target);
+
+    // Query the default heuristics for the LHS/RHS block configuration
+    const auto        heuristic = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ target, dt, rows_m, cols_n, depth_k, batches });
+    GEMMLHSMatrixInfo lhs_info  = heuristic.lhs_info;
+    GEMMRHSMatrixInfo rhs_info  = heuristic.rhs_info;
+    rhs_info.h0                 = 4; // Force H0 to 4 in order to use the MMUL extension
+
+    // Configure the RHS reshape kernel, writing into the temporary reshaped tensor info
+    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
+
+    // Configure the matrix multiply kernel on the reshaped RHS
+    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, mm_desc);
+
+    // Request memory for the reshaped RHS: Persistent when it is reshaped only on the first run, Temporary otherwise
+    const auto rhs_lifetime = _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary;
+    _aux_mem[RhsReshape]    = MemoryInfo(offset_int_vec(RhsReshape), rhs_lifetime, _tmp_b.total_size());
+}
+
 Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
@@ -458,6 +506,54 @@
     return Status{};
 }
 
+Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
+{
+    // Validates the RHS reshape kernel and the MMUL matrix multiply kernel for the given tensor infos.
+    // Every parameter is forwarded to the kernel validation below; an empty Status is returned on success.
+    TensorInfo tmp_b_info{};
+
+    // Get the GPU target
+    const GPUTarget    gpu_target              = CLScheduler::get().target();
+    const DataType     data_type               = a->data_type();
+    const bool         reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
+    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
+    const unsigned int n                       = b->dimension(0);
+    const unsigned int k                       = a->dimension(0);
+    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
+    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
+    const bool         broadcast_bias          = gemm_info.broadcast_bias();
+
+    GEMMKernelInfo kernel_info;
+    kernel_info.m                       = m;
+    kernel_info.n                       = n;
+    kernel_info.k                       = k;
+    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
+    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
+    kernel_info.broadcast_bias          = broadcast_bias;
+    kernel_info.activation_info         = gemm_info.activation_info();
+    kernel_info.post_ops                = gemm_info.post_ops();
+
+    GEMMLHSMatrixInfo lhs_info;
+    GEMMRHSMatrixInfo rhs_info;
+
+    // Pick up the GEMM configuration
+    // NOTE: Only the default heuristics are queried on this path
+    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+    lhs_info               = gemm_config.lhs_info;
+    rhs_info               = gemm_config.rhs_info;
+    // Force H0 to 4 in order to use the MMUL extension
+    rhs_info.h0 = 4;
+
+    // Build the reshaped RHS tensor info and validate the reshape kernel against it
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));
+
+    // Validate matrix multiply with no y padding support
+    kernel_info.has_pad_y = false;
+    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
+
+    return Status{};
+}
+
 void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
@@ -501,6 +597,11 @@
             configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
+            break;
+        }
         default:
         {
             ARM_COMPUTE_ERROR("GEMMType not supported");
@@ -545,6 +646,11 @@
             ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
+            break;
+        }
         default:
         {
             ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
@@ -627,6 +733,34 @@
             }
             break;
         }
+        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
+        {
+            if(!_reshape_b_only_on_first_run)
+            {
+                // Run the RHS reshape kernel
+                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
+                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
+            }
+            // In case of RESHAPED_ONLY_RHS_MMUL, we need to check the padding requirement
+            // Check if the lhs or dst tensors have padding
+            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
+            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
+            bool               has_pad_y           = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
+
+            // Copy original tensor pack and overwrite rhs with reshaped counterpart
+            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
+            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
+
+            if(has_pad_y)
+            {
+                ARM_COMPUTE_ERROR_ON(has_pad_y);
+            }
+            else
+            {
+                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
+            }
+            break;
+        }
         default:
         {
             ARM_COMPUTE_ERROR("GEMMType not supported");