COMPMID-1451: Reverting changes for CLGEMM and CLGEMMLowp previously done (384496)
              Mirroring CLGEMM behaviour to CLGEMMLowp

Change-Id: I308b54e2c0de131a5322b77e83e7454db498d692
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/153175
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: bsgcomp <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 0c82e6d..6adbdc0 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -133,10 +133,7 @@
     if(_is_interleaved_transposed)
     {
         reinterpret_input_as_3d = false;
-    }
 
-    if(_is_interleaved_transposed)
-    {
         matrix_a = &_tmp_a;
         matrix_b = &_tmp_b;
 
@@ -200,8 +197,7 @@
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
     bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
-    bool      reinterpret_output_as_3d  = (gemm_info.depth_output_gemm3d() != 1);
-    const int m                         = a->dimension(1);
+    const int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const int n                         = b->dimension(0);
     const int k                         = a->dimension(0);
     int       mult_transpose1xW_width   = 1;
@@ -217,21 +213,13 @@
     // Check if we need to reshape the matrix A and matrix B
     const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
 
-    // In case both input and output have to be reinterpreted as 3D tensors,
-    // force reinterpret_input_as_3d and reinterpret_output_as_3d to be false.
-    if(reinterpret_input_as_3d == reinterpret_output_as_3d)
-    {
-        reinterpret_input_as_3d  = false;
-        reinterpret_output_as_3d = false;
-    }
-
     // if _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleaveKernel will be 2D
     if(run_interleave_transpose)
     {
         reinterpret_input_as_3d = false;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, reinterpret_output_as_3d ? depth_output_gemm3d : 1, reinterpret_input_as_3d);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
 
     if(run_interleave_transpose)
     {
@@ -239,8 +227,8 @@
         matrix_b_info = &tmp_b_info;
 
         // Validate interleave kernel
-        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, reinterpret_input_as_3d)));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, reinterpret_input_as_3d));
+        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
 
         // Validate transpose kernel
         auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));