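COMPMID-1710: Fixing gemm_mm_reshaped_lhs_nt_rhs_t with REINTERPRET_OUTPUT_AS_3D

When REINTERPRET_OUTPUT_AS_3D is enabled, the per-row output offsets
(zout0..zout7) are now computed as
plane_index * dst_cross_plane_pad * dst_stride_y instead of using
dst_stride_z. The kernel documentation is also corrected: the rhs_ptr
and dst_ptr data types are described as "same as lhs_ptr".
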
Change-Id: I9af1f7263c6e71e38af97f3112d35044cf60ddf0
Reviewed-on: https://review.mlplatform.org/403
Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index d37dd2d..44b50b3 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -942,13 +942,13 @@
  * @param[in]  lhs_stride_y                      Stride of the LHS reshaped matrix in Y dimension (in bytes)
  * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
- * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p src0_ptr
+ * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
  * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
  * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: S32
+ * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
@@ -1182,41 +1182,41 @@
     // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
     zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
-    zout0 *= (dst_cross_plane_pad * dst_stride_z);
+    zout0 *= (dst_cross_plane_pad * dst_stride_y);
 #if M0 > 1
     zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
-    zout1 *= (dst_cross_plane_pad * dst_stride_z);
+    zout1 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 1
 #if M0 > 2
     zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
-    zout2 *= (dst_cross_plane_pad * dst_stride_z);
+    zout2 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 2
 #if M0 > 3
     zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
-    zout3 *= (dst_cross_plane_pad * dst_stride_z);
+    zout3 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 3
 #if M0 > 4
     zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
-    zout4 *= (dst_cross_plane_pad * dst_stride_z);
+    zout4 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 4
 #if M0 > 5
     zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
-    zout5 *= (dst_cross_plane_pad * dst_stride_z);
+    zout5 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 5
 #if M0 > 6
     zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
-    zout6 *= (dst_cross_plane_pad * dst_stride_z);
+    zout6 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 6
 #if M0 > 6
     zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
     zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
-    zout7 *= (dst_cross_plane_pad * dst_stride_z);
+    zout7 *= (dst_cross_plane_pad * dst_stride_y);
 #endif // M0 > 7
 
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we