COMPMID-1518: Add support for GEMM3D in CLGEMMLowpMatrixMultiplyCore

Change-Id: Ib14ac821ee5d4aff80bd602cd3e76e7018abb5e6
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/150268
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp b/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
index aa954ab..3888353 100644
--- a/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.cpp
@@ -62,16 +62,24 @@
     if(b_offset != 0)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32);
-        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_row->dimension(0) != mm_result->dimension(1));
+
+        // Check if input is a 3D reinterpretation
+        const bool reinterpret_as_3d = vector_sum_row != nullptr && mm_result->num_dimensions() > 1 && mm_result->tensor_shape().y() != vector_sum_row->tensor_shape().x();
+
+        // Validate input
+        ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) != (mm_result->dimension(1) * mm_result->dimension(2)));
+        ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row != nullptr && vector_sum_row->dimension(0) != mm_result->dimension(1));
 
         TensorShape output_shape = mm_result->tensor_shape();
         if(output_shape.num_dimensions() > 1)
         {
+            const unsigned int output_batch_idx = reinterpret_as_3d ? 3 : 2;
+
             TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape();
             vector_sum_row_shape.collapse_from(1);
-            output_shape.collapse_from(2);
+            output_shape.collapse_from(output_batch_idx);
 
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[2],
+            ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != output_shape[output_batch_idx],
                                             "mm_result tensor must have the same number of batches of output tensor");
 
             if(a_offset != 0)
@@ -98,20 +106,17 @@
     Window win = calculate_max_window(*mm_result, Steps(num_elems_processed_per_iteration));
 
     AccessWindowHorizontal mm_result_access(mm_result, 0, num_elems_processed_per_iteration);
-    window_changed = window_changed || update_window_and_padding(win,
-                                                                 mm_result_access);
+    window_changed = window_changed || update_window_and_padding(win, mm_result_access);
 
     if(a_offset != 0)
     {
         AccessWindowHorizontal vector_sum_col_access(vector_sum_col, 0, num_elems_processed_per_iteration);
-        window_changed = window_changed || update_window_and_padding(win,
-                                                                     vector_sum_col_access);
+        window_changed = window_changed || update_window_and_padding(win, vector_sum_col_access);
     }
     if(b_offset != 0)
     {
         AccessWindowStatic vector_sum_row_access(vector_sum_row, 0, 0, vector_sum_row->dimension(0), 0); // NOLINT
-        window_changed = window_changed || update_window_and_padding(win,
-                                                                     vector_sum_row_access);
+        window_changed = window_changed || update_window_and_padding(win, vector_sum_row_access);
     }
 
     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
@@ -137,6 +142,11 @@
     _vector_sum_row = vector_sum_row;
     _mm_result      = mm_result;
 
+    // Check if input is a 3D reinterpretation
+    const bool reinterpret_as_3d = vector_sum_row != nullptr
+                                   && mm_result->info()->num_dimensions() > 1
+                                   && mm_result->info()->tensor_shape().y() != vector_sum_row->info()->tensor_shape().x();
+
     // Set the arguments to pass at compile time
     CLBuildOptions build_opts;
 
@@ -149,6 +159,8 @@
     // If b_offset == 0, vector_sum_row can be a nullptr
     build_opts.add_option_if(b_offset != 0, "-DB_OFFSET=" + support::cpp11::to_string(b_offset));
     build_opts.add_option("-DK_OFFSET=" + support::cpp11::to_string(a_offset * b_offset * k));
+    build_opts.add_option_if(reinterpret_as_3d, "-DHEIGHT_INPUT3D=" + support::cpp11::to_string(mm_result->info()->dimension(1)));
+    build_opts.add_option_if(reinterpret_as_3d, "-DDEPTH_INPUT3D=" + support::cpp11::to_string(mm_result->info()->dimension(2)));
 
     // Create kernel
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_offset_contribution", build_opts.options()));
@@ -194,11 +206,13 @@
     // Set window for vector_sum_col
     Window win_vector_sum_col = slice;
     win_vector_sum_col.set(Window::DimY, Window::Dimension(0, 0, 0));
+    win_vector_sum_col.set(Window::DimZ, Window::Dimension(0, 0, 0));
 
     // Set window for vector_sum_row
     Window win_vector_sum_row = slice;
     win_vector_sum_row.set(Window::DimX, Window::Dimension(0, 0, 0));
     win_vector_sum_row.set(Window::DimY, Window::Dimension(0, 0, 0));
+    win_vector_sum_row.set(Window::DimZ, Window::Dimension(0, 0, 0));
 
     do
     {