COMPMID-1518: Add support for GEMM3D in CLGEMMLowpMatrixMultiplyCore

Change-Id: Ib14ac821ee5d4aff80bd602cd3e76e7018abb5e6
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/150268
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 763ebce..1d6f343 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -107,15 +107,23 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
+    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
     const int     m                         = a->info()->dimension(1);
     const int     n                         = b->info()->dimension(0);
     const int     k                         = a->info()->dimension(0);
+    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
     constexpr int mult_transpose1xW_width   = 1;
     constexpr int mult_interleave4x4_height = 1;
 
     // Check if we need to reshape the matrix A and matrix B
     _is_interleaved_transposed = is_interleaved_transposed(m, n, k, _reshape_b_only_on_first_run, gpu_target);
 
+    // If _is_interleaved_transposed is set, force reinterpret_input_as_3d to false because the output of CLGEMMInterleaveKernel will be 2D
+    if(_is_interleaved_transposed)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
     if(_is_interleaved_transposed)
     {
         matrix_a = &_tmp_a;
@@ -128,14 +136,15 @@
         }
 
         // Configure interleave kernel
-        _mtx_a_reshape_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+        _mtx_a_reshape_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
 
         // Configure transpose kernel
         _mtx_b_reshape_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
-
     // Configure matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+    _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
+                                                                                                 mult_transpose1xW_width, mult_interleave4x4_height,
+                                                                                                 depth_output_gemm3d, reinterpret_input_as_3d));
 
     // Initialize matrix B reduction kernel only if _a_offset is not equal to 0
     if(_a_offset != 0)
@@ -191,28 +200,30 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(0) != (b)->dimension(1),
-                                    "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(1) != (output)->dimension(1),
-                                    "The output matrix must have the same number of rows as the matrix A");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((b)->dimension(0) != (output)->dimension(0),
-                                    "The output matrix must have the same number of columns as the matrix B");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
 
     int32_t a_offset = a->quantization_info().offset;
     int32_t b_offset = b->quantization_info().offset;
 
-    const int             m                         = a->dimension(1);
-    const int             n                         = b->dimension(0);
-    const int             k                         = a->dimension(0);
-    constexpr int         mult_transpose1xW_width   = 1;
-    constexpr int         mult_interleave4x4_height = 1;
-    const int             depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
+    const int     m                         = a->dimension(1);
+    const int     n                         = b->dimension(0);
+    const int     k                         = a->dimension(0);
+    constexpr int mult_transpose1xW_width   = 1;
+    constexpr int mult_interleave4x4_height = 1;
+    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
+    // If reshape_matrices is set, force reinterpret_input_as_3d to false because the output of CLGEMMInterleaveKernel will be 2D
+    if(reshape_matrices)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
     if(reshape_matrices)
     {
         TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());