COMPMID-2675: Fix arguments passed at compile time for GEMM - OpenCL

Change-Id: I47b84a6f815492e24771d488aa8b29d14e572f40
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1956
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index d7484d7..8e628e8 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2644,12 +2644,12 @@
 #if K0 > 8
         RHS_VFMA_M0xN0(8, a, b8, c);
         RHS_VFMA_M0xN0(9, a, b9, c);
-        RHS_VFMA_M0xN0(A, a, b10, c);
-        RHS_VFMA_M0xN0(B, a, b11, c);
-        RHS_VFMA_M0xN0(C, a, b12, c);
-        RHS_VFMA_M0xN0(D, a, b13, c);
-        RHS_VFMA_M0xN0(E, a, b14, c);
-        RHS_VFMA_M0xN0(F, a, b15, c);
+        RHS_VFMA_M0xN0(A, a, bA, c);
+        RHS_VFMA_M0xN0(B, a, bB, c);
+        RHS_VFMA_M0xN0(C, a, bC, c);
+        RHS_VFMA_M0xN0(D, a, bD, c);
+        RHS_VFMA_M0xN0(E, a, bE, c);
+        RHS_VFMA_M0xN0(F, a, bF, c);
 #endif // K0 > 8
 
         lhs_offset += K0 * sizeof(DATA_TYPE);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
index b1d0059..b00faed 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyNativeKernel.cpp
@@ -155,7 +155,7 @@
 
     // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
     // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
-    const unsigned int m          = reinterpret_output_as_3d ? gemm_info.m : input0->dimension(1);
+    const unsigned int m          = reinterpret_output_as_3d ? gemm_info.m : output->dimension(1);
     const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
 
     win     = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
@@ -246,6 +246,14 @@
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
+    // If both _reinterpret_input_as_3d and _reinterpret_output_as_3d are true,
+    // a batched GEMM is dispatched to reduce the complexity of the address calculation within the OpenCL kernel.
+    // In that case the actual m used by the kernel is given by output->info()->dimension(1) and not by gemm_info.m.
+    const unsigned int internal_m = _reinterpret_output_as_3d ? gemm_info.m : output->info()->dimension(1);
+
+    const unsigned int h_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(1) : input0->info()->dimension(1);
+    const unsigned int d_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(2) : input0->info()->dimension(2);
+
     // Create build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
@@ -255,11 +263,11 @@
     build_opts.add_option_if(gemm_info.broadcast_bias, "-DBROADCAST_BIAS");
     build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
     build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
-    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
-    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(h_gemm_3d));
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(d_gemm_3d));
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
     build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
-    build_opts.add_option("-DM=" + support::cpp11::to_string(input0->info()->dimension(1)));
+    build_opts.add_option("-DM=" + support::cpp11::to_string(internal_m));
     build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n));
     build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k));
     build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 0e9ca78..fff4da6 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -158,7 +158,7 @@
 
     // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
     // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
-    const unsigned int m          = reinterpret_output_as_3d ? gemm_info.m : input0->dimension(1);
+    const unsigned int m          = reinterpret_output_as_3d ? gemm_info.m : output->dimension(1);
     const unsigned int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
 
     win     = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
@@ -249,6 +249,14 @@
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure_internal(win_config.second);
 
+    // If both _reinterpret_input_as_3d and _reinterpret_output_as_3d are true,
+    // a batched GEMM is dispatched to reduce the complexity of the address calculation within the OpenCL kernel.
+    // In that case the actual m used by the kernel is given by output->info()->dimension(1) and not by gemm_info.m.
+    const unsigned int internal_m = _reinterpret_output_as_3d ? gemm_info.m : output->info()->dimension(1);
+
+    const unsigned int h_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(1) : input0->info()->dimension(1);
+    const unsigned int d_gemm_3d = _reinterpret_output_as_3d ? output->info()->dimension(2) : input0->info()->dimension(2);
+
     // Create build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
@@ -258,12 +266,12 @@
     build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
     build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
     build_opts.add_option_if(gemm_info.broadcast_bias, "-DBROADCAST_BIAS");
-    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
-    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(h_gemm_3d));
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(d_gemm_3d));
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
     build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE");
     build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS");
-    build_opts.add_option("-DM=" + support::cpp11::to_string(input0->info()->dimension(1)));
+    build_opts.add_option("-DM=" + support::cpp11::to_string(internal_m));
     build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n));
     build_opts.add_option("-DK=" + support::cpp11::to_string(gemm_info.k));
     build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));