COMPMID-3770: Add batch size to the OpenCL GEMM kernel selection

Change-Id: Ia3030ea701e9ceb2ef567e0258e8f478e18b8b55
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3871
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h
index cbc5253..19095a5 100644
--- a/arm_compute/runtime/CL/CLTypes.h
+++ b/arm_compute/runtime/CL/CLTypes.h
@@ -53,6 +53,7 @@
     unsigned int m{ 0 };                         /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */
     unsigned int n{ 0 };                         /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */
     unsigned int k{ 0 };                         /**< Number of rows for the rhs matrix. Rhs matrix NOT transposed */
+    unsigned int b{ 0 };                         /**< Batch size */
     bool         is_rhs_constant{ false };       /**< True if the content of the rhs matrix is constant */
     DataType     data_type{ DataType::UNKNOWN }; /**< Data type */
 };
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 8e4d390..6e9cf0e 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -185,7 +185,7 @@
     void prepare() override;
 
 private:
-    static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run);
+    static CLGEMMKernelType select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run);
 
     void configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
     void configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info);
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
index 815c2c8..579bbe3 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionBifrost.h
@@ -44,11 +44,11 @@
     CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
 
 private:
-    CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+    CLGEMMKernelType g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
 };
 } // namespace cl_gemm
 } // namespace arm_compute
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
index 4689f0c..5547731 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionMidgard.h
@@ -44,9 +44,9 @@
     CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
 
 private:
-    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
 };
 } // namespace cl_gemm
 } // namespace arm_compute
diff --git a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
index 8712be7..782ef74 100644
--- a/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
+++ b/arm_compute/runtime/CL/gemm/CLGEMMKernelSelectionValhall.h
@@ -44,9 +44,9 @@
     CLGEMMKernelType select_kernel(const CLGEMMKernelSelectionParams &params) override;
 
 private:
-    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
-    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, bool is_rhs_constant);
+    CLGEMMKernelType default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
+    CLGEMMKernelType default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
 };
 } // namespace cl_gemm
 } // namespace arm_compute