COMPMID-1518: Add support for GEMM3D in CLGEMMLowpMatrixMultiplyCore

Change-Id: Ib14ac821ee5d4aff80bd602cd3e76e7018abb5e6
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/150268
Tested-by: bsgcomp <bsgcomp@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 9dbfd3e..821464e 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -72,8 +72,18 @@
 } // namespace
 
 CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr), _is_interleaved_transposed(false),
-      _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+    : _memory_group(std::move(memory_manager)),
+      _interleave_kernel(),
+      _transpose_kernel(),
+      _mm_kernel(),
+      _ma_kernel(),
+      _tmp_a(),
+      _tmp_b(),
+      _original_b(nullptr),
+      _is_interleaved_transposed(false),
+      _run_addition(false),
+      _reshape_b_only_on_first_run(false),
+      _is_prepared(false)
 {
 }
 
@@ -146,8 +156,9 @@
     }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
-                                                                                                        reinterpret_input_as_3d));
+    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
+                                                                                                        mult_transpose1xW_width, mult_interleave4x4_height,
+                                                                                                        depth_output_gemm3d, reinterpret_input_as_3d));
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index c9daea4..bd5e969 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -130,9 +130,10 @@
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
-    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
     if(is_quantized)
     {
+        const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */);
+
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
         // Extract and negate input and weights offset
         const QuantizationInfo input_quantization_info   = input->quantization_info();
@@ -148,6 +149,8 @@
     }
     else
     {
+        const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
+
         // Perform validation step on Matrix multiply function
         return CLGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
     }
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 763ebce..1d6f343 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -107,15 +107,23 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
+    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
     const int     m                         = a->info()->dimension(1);
     const int     n                         = b->info()->dimension(0);
     const int     k                         = a->info()->dimension(0);
+    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
     constexpr int mult_transpose1xW_width   = 1;
     constexpr int mult_interleave4x4_height = 1;
 
     // Check if we need to reshape the matrix A and matrix B
     _is_interleaved_transposed = is_interleaved_transposed(m, n, k, _reshape_b_only_on_first_run, gpu_target);
 
+    // If _is_interleaved_transposed is set, force reinterpret_input_as_3d to false, because the output of CLGEMMInterleaveKernel will be 2D
+    if(_is_interleaved_transposed)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
     if(_is_interleaved_transposed)
     {
         matrix_a = &_tmp_a;
@@ -128,14 +136,15 @@
         }
 
         // Configure interleave kernel
-        _mtx_a_reshape_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+        _mtx_a_reshape_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
 
         // Configure transpose kernel
         _mtx_b_reshape_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
-
     // Configure matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+    _mm_kernel.configure(matrix_a, matrix_b, output, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k,
+                                                                                                 mult_transpose1xW_width, mult_interleave4x4_height,
+                                                                                                 depth_output_gemm3d, reinterpret_input_as_3d));
 
     // Initialize matrix B reduction kernel only if _a_offset is not equal to 0
     if(_a_offset != 0)
@@ -191,28 +200,30 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(0) != (b)->dimension(1),
-                                    "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((a)->dimension(1) != (output)->dimension(1),
-                                    "The output matrix must have the same number of rows as the matrix A");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG((b)->dimension(0) != (output)->dimension(0),
-                                    "The output matrix must have the same number of columns as the matrix B");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
 
     int32_t a_offset = a->quantization_info().offset;
     int32_t b_offset = b->quantization_info().offset;
 
-    const int             m                         = a->dimension(1);
-    const int             n                         = b->dimension(0);
-    const int             k                         = a->dimension(0);
-    constexpr int         mult_transpose1xW_width   = 1;
-    constexpr int         mult_interleave4x4_height = 1;
-    const int             depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
-    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
+    const int     m                         = a->dimension(1);
+    const int     n                         = b->dimension(0);
+    const int     k                         = a->dimension(0);
+    constexpr int mult_transpose1xW_width   = 1;
+    constexpr int mult_interleave4x4_height = 1;
+    bool          reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int     depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
+    // If reshape_matrices is set, force reinterpret_input_as_3d to false, because the output of CLGEMMInterleaveKernel will be 2D
+    if(reshape_matrices)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
     if(reshape_matrices)
     {
         TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());