COMPMID-2097: Implement a heuristic to dispatch CLGEMMReshapedOnlyRHS kernel from CLGEMM

Change-Id: I4170a80647b02501aa669e2c0347ddc39888ee76
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/928
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 0bad446..8c462fa 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -27,6 +27,7 @@
 #include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
 #include "arm_compute/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
 #include "arm_compute/runtime/CL/CLMemoryGroup.h"
@@ -40,10 +41,11 @@
 
 /** Basic function to execute GEMM on OpenCL. This function calls the following OpenCL kernels:
  *
- *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model)
- *  -# @ref CLGEMMReshapeRHSMatrixKernel (only if the reshaped GEMM is selected by the heuristic model)
- *  -# @ref CLGEMMMatrixMultiplyKernel (if GPU target is NOT G76 or if the reshaped GEMM is NOT selected)
- *  -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if the reshaped GEMM is selected by the heuristic model and the GPU target IS Mali-G76)
+ *  -# @ref CLGEMMReshapeLHSMatrixKernel (only if RESHAPED_V1 is selected by the select_gemm_type() method)
+ *  -# @ref CLGEMMReshapeRHSMatrixKernel (only if either RESHAPED_V1 or RESHAPED_ONLY_RHS is selected by the select_gemm_type() method)
+ *  -# @ref CLGEMMMatrixMultiplyKernel (only if either NATIVE or RESHAPED_V1 is selected by the select_gemm_type() method)
+ *  -# @ref CLGEMMMatrixMultiplyReshapedKernel (only if RESHAPED_V1 is selected by the select_gemm_type() method)
+ *  -# @ref CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_type() method)
  *  -# @ref CLGEMMMatrixAdditionKernel (if c != nullptr and beta != 0.0)
  *
  */
@@ -102,20 +104,41 @@
     void prepare() override;
 
 private:
-    CLMemoryGroup                      _memory_group;
-    CLGEMMMatrixMultiplyKernel         _mm_kernel;
-    CLGEMMMatrixAdditionKernel         _ma_kernel;
-    CLGEMMReshapeLHSMatrixKernel       _reshape_lhs_kernel;
-    CLGEMMReshapeRHSMatrixKernel       _reshape_rhs_kernel;
-    CLGEMMMatrixMultiplyReshapedKernel _mm_reshaped_kernel;
-    CLTensor                           _tmp_a;
-    CLTensor                           _tmp_b;
-    const ICLTensor                   *_original_b;
-    bool                               _is_interleaved_transposed;
-    bool                               _run_addition;
-    bool                               _reshape_b_only_on_first_run;
-    bool                               _is_prepared;
-    bool                               _is_new_gemm_reshaped; // Remove when COMPMID-1892 is completed
+    enum class GEMMType /**< GEMM variant this function dispatches; stored in _gemm_type */
+    {
+        NATIVE,           /**< Multiplication through CLGEMMMatrixMultiplyKernel; no reshape kernel is listed for this variant */
+        RESHAPED_V1,      /**< Both LHS and RHS matrices are reshaped before the multiplication */
+        RESHAPED_V2,      /**< NOTE(review): kernel mapping not stated in this header - presumably the newer LHS+RHS reshaped path; confirm in CLGEMM.cpp */
+        RESHAPED_ONLY_RHS /**< Only the RHS matrix is reshaped; multiplication through CLGEMMMatrixMultiplyReshapedOnlyRHSKernel */
+    };
+
+    // Heuristic returning the GEMMType to dispatch for the given GEMM dimensions (m, n, k), data type, reshape_b_only_on_first_run flag and GPU target. TODO (COMPMID-2095)
+    static GEMMType select_gemm_type(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run, GPUTarget gpu_target);
+
+    void configure_native(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    void configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    void configure_reshaped_v2(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    void configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
+
+    static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    static Status validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    static Status validate_reshaped_v2(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+    static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
+
+    CLMemoryGroup                             _memory_group;
+    CLGEMMMatrixMultiplyKernel                _mm_kernel;
+    CLGEMMMatrixAdditionKernel                _ma_kernel;
+    CLGEMMReshapeLHSMatrixKernel              _reshape_lhs_kernel;
+    CLGEMMReshapeRHSMatrixKernel              _reshape_rhs_kernel;
+    CLGEMMMatrixMultiplyReshapedKernel        _mm_reshaped_kernel;
+    CLGEMMMatrixMultiplyReshapedOnlyRHSKernel _mm_reshaped_only_rhs_kernel;
+    CLTensor                                  _tmp_a;
+    CLTensor                                  _tmp_b;
+    const ICLTensor                          *_original_b;
+    bool                                      _run_addition;
+    bool                                      _reshape_b_only_on_first_run;
+    bool                                      _is_prepared;
+    GEMMType                                  _gemm_type;
 };
 } // namespace arm_compute