COMPMID-1276 - Allow GEMM to work with 3D input tensor

Skipped im2col in CLGEMMConvolutionLayer for 1x1 convolutions with NHWC data layout

Change-Id: I894e6b952ed8605e8f3ffc0ffc25c24730d4664c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141909
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
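
The kernels below handle the cross-plane padding that appears when matrix A is read directly from a 3D (NHWC) tensor instead of an im2col output. As a rough host-side reference, a minimal sketch of that row-to-offset mapping, assuming illustrative names (input_row_offset_bytes and its parameters are not part of this patch):

    #include <algorithm>
    #include <cstdint>

    // Maps a row of the flattened matrix A (M = HEIGHT_GEMM3D * DEPTH_GEMM3D rows) to a byte offset
    // in the 3D input, skipping cross_plane_pad padded rows whenever a Z-plane boundary is crossed.
    // This mirrors the zin computation performed by the kernels when REINTERPRET_INPUT_AS_3D is defined.
    uint32_t input_row_offset_bytes(uint32_t row,             // row index of the flattened matrix A
                                    uint32_t height_gemm3d,   // HEIGHT_GEMM3D: rows per Z-plane
                                    uint32_t depth_gemm3d,    // DEPTH_GEMM3D: number of Z-planes
                                    uint32_t cross_plane_pad, // padded rows between consecutive planes
                                    uint32_t stride_y_bytes)  // src_stride_y of the input tensor
    {
        // Plane index, clamped as in the kernels: zin = min(row / HEIGHT_GEMM3D, DEPTH_GEMM3D - 1)
        const uint32_t zin = std::min(row / height_gemm3d, depth_gemm3d - 1);

        // Each crossed plane boundary contributes cross_plane_pad extra rows of padding
        return row * stride_y_bytes + zin * cross_plane_pad * stride_y_bytes;
    }

At run time the pad is taken from the input's allocated padding (padding().top + padding().bottom) and passed to the kernels as the new cross_plane_pad argument.
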
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 5a6efe6..932e0d6 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -88,6 +88,11 @@
  *
  * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
  * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the input has to be reinterpreted as a 3D tensor (i.e. the input of a 1x1 convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
+ *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
  *
  * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
  * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
@@ -105,9 +110,15 @@
  * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
  * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
+ * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
  */
 __kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
-                                 TENSOR3D_DECLARATION(dst))
+                                 TENSOR3D_DECLARATION(dst)
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                 ,
+                                 uint cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
+                                )
 {
     // Compute source and destination addresses
     uint x = get_global_id(0);
@@ -124,6 +135,45 @@
     // Add offset for batched GEMM
     dst_addr_in_bytes += z * dst_stride_z;
 
+#if defined(REINTERPRET_INPUT_AS_3D)
+    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;
+
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (y * 4) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (cross_plane_pad * src_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src_stride_z by DEPTH_GEMM3D
+    input_ptr += z * src_stride_z * DEPTH_GEMM3D;
+
+    // Load values from Matrix A
+    VEC_DATA_TYPE(DATA_TYPE, 4)
+    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
+    VEC_DATA_TYPE(DATA_TYPE, 4)
+    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
+    VEC_DATA_TYPE(DATA_TYPE, 4)
+    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
+    VEC_DATA_TYPE(DATA_TYPE, 4)
+    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
+#else  // defined(REINTERPRET_INPUT_AS_3D)
     __global uchar *input_ptr = src.ptr;
 
     // Load values from Matrix A
@@ -135,6 +185,7 @@
     a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
     VEC_DATA_TYPE(DATA_TYPE, 4)
     a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
 
     VEC_DATA_TYPE(DATA_TYPE, 4)
     val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
@@ -188,7 +239,7 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
@@ -366,7 +417,7 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
@@ -679,7 +730,7 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
@@ -853,7 +904,7 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
@@ -1095,7 +1146,8 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1122,7 +1174,8 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                      IMAGE_DECLARATION(src1),
@@ -1130,9 +1183,13 @@
                                      uint src0_stride_z,
                                      uint src1_stride_z,
                                      uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                     ,
+                                     uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                      ,
-                                     uint cross_plane_pad
+                                     uint dst_cross_plane_pad
 #endif // REINTERPRET_OUTPUT_AS_3D
                                     )
 {
@@ -1147,9 +1204,40 @@
     // Update address for the matrix B
     src_addr.s1 += idx * sizeof(DATA_TYPE);
 
+#if defined(REINTERPRET_INPUT_AS_3D)
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (src_cross_plane_pad * src0_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src0_stride_z by DEPTH_GEMM3D
+    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
     // Add offset for batched GEMM
     src_addr.s0 += get_global_id(2) * src0_stride_z;
 
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
 #if defined(MATRIX_B_DEPTH)
     // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
     src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1172,6 +1260,23 @@
 
     for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        VEC_DATA_TYPE(DATA_TYPE, 2)
+        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        VEC_DATA_TYPE(DATA_TYPE, 2)
+        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        VEC_DATA_TYPE(DATA_TYPE, 2)
+        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        VEC_DATA_TYPE(DATA_TYPE, 2)
+        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         VEC_DATA_TYPE(DATA_TYPE, 2)
         a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1187,6 +1292,8 @@
         VEC_DATA_TYPE(DATA_TYPE, 2)
         a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
         VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
@@ -1210,6 +1317,19 @@
 
     for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1221,6 +1341,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
 
@@ -1280,7 +1402,7 @@
     zout       = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
-    zout *= (cross_plane_pad * dst_stride_y);
+    zout *= (dst_cross_plane_pad * dst_stride_y);
 
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
@@ -1335,7 +1457,8 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1362,7 +1485,8 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
@@ -1370,9 +1494,13 @@
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
                                                  uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                                 ,
+                                                 uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                  ,
-                                                 uint cross_plane_pad
+                                                 uint dst_cross_plane_pad
 #endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
 {
@@ -1387,9 +1515,40 @@
     // Update address for matrix B
     src_addr.s1 += idx * sizeof(float);
 
+#if defined(REINTERPRET_INPUT_AS_3D)
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (src_cross_plane_pad * src0_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src0_stride_z by DEPTH_GEMM3D
+    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
     // Add offset for batched GEMM
     src_addr.s0 += get_global_id(2) * src0_stride_z;
 
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
 #if defined(MATRIX_B_DEPTH)
     // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
     src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1428,6 +1587,19 @@
     int i = 0;
     for(; i <= ((int)COLS_A - 4); i += 4)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A and matrix B
+        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A and matrix B
         float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1439,6 +1611,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
         src_addr.s1 += src1_stride_y;
 
@@ -1579,8 +1753,21 @@
 
     for(; i < (int)COLS_A; ++i)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
-        float a0 = *((__global float *)(src0_ptr + src_addr.s0));
+        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1590,6 +1777,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
         src_addr.s1 += src1_stride_y;
@@ -1676,7 +1865,7 @@
     zout       = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
-    zout *= (cross_plane_pad * dst_stride_y);
+    zout *= (dst_cross_plane_pad * dst_stride_y);
 
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
@@ -1723,7 +1912,8 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -1750,7 +1940,8 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
@@ -1758,9 +1949,13 @@
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                                      ,
+                                                      uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
-                                                      uint cross_plane_pad
+                                                      uint dst_cross_plane_pad
 #endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
 {
@@ -1776,9 +1971,40 @@
     // Update address for the matrix B
     src_addr.s1 += idx * sizeof(float);
 
+#if defined(REINTERPRET_INPUT_AS_3D)
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (src_cross_plane_pad * src0_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src0_stride_z by DEPTH_GEMM3D
+    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
     // Add offset for batched GEMM
     src_addr.s0 += get_global_id(2) * src0_stride_z;
 
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
 #if defined(MATRIX_B_DEPTH)
     // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
     src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -1807,8 +2033,13 @@
     int i = 0;
     for(; i <= ((int)COLS_A - 8); i += 8)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
 
         // Load values from matrix B
         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
@@ -1848,7 +2079,11 @@
         acc01 = fma(a0.s7, b7.s1, acc01);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
         acc10 = fma(a0.s0, b0.s0, acc10);
         acc10 = fma(a0.s1, b1.s0, acc10);
         acc10 = fma(a0.s2, b2.s0, acc10);
@@ -1868,7 +2103,11 @@
         acc11 = fma(a0.s7, b7.s1, acc11);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
         acc20 = fma(a0.s0, b0.s0, acc20);
         acc20 = fma(a0.s1, b1.s0, acc20);
         acc20 = fma(a0.s2, b2.s0, acc20);
@@ -1888,7 +2127,11 @@
         acc21 = fma(a0.s7, b7.s1, acc21);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#if defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#else  // defined(REINTERPRET_INPUT_AS_3D)
+        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+#endif // defined(REINTERPRET_INPUT_AS_3D)
         acc30 = fma(a0.s0, b0.s0, acc30);
         acc30 = fma(a0.s1, b1.s0, acc30);
         acc30 = fma(a0.s2, b2.s0, acc30);
@@ -1913,6 +2156,19 @@
     // float size increment
     for(; i < (int)COLS_A; ++i)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1924,6 +2180,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
         src_addr.s1 += src1_stride_y;
@@ -1994,7 +2252,7 @@
     zout       = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
-    zout *= (cross_plane_pad * dst_stride_y);
+    zout *= (dst_cross_plane_pad * dst_stride_y);
 
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
@@ -2041,7 +2299,8 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
- * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
+ *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
  *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
  *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
  *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
@@ -2068,7 +2327,8 @@
  * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
  * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
- * @param[in]  cross_plane_pad                    Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
+ * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
+ * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
@@ -2076,9 +2336,13 @@
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
                                                  uint dst_stride_z
+#if defined(REINTERPRET_INPUT_AS_3D)
+                                                 ,
+                                                 uint src_cross_plane_pad
+#endif // REINTERPRET_INPUT_AS_3D
 #if defined(REINTERPRET_OUTPUT_AS_3D)
                                                  ,
-                                                 uint cross_plane_pad
+                                                 uint dst_cross_plane_pad
 #endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
 {
@@ -2093,9 +2357,40 @@
     // Update address for the matrix B
     src_addr.s1 += idx * sizeof(half);
 
+#if defined(REINTERPRET_INPUT_AS_3D)
+    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible cross plane paddings
+    //
+    //  |                  |
+    //  |      plane0      |
+    //  |                  |
+    //  |__________________|
+    //  |******************|
+    //  |  cross_plane_pad |
+    //  |******************|
+    //  |                  |
+    //  |      plane1      |
+    //  |                  |
+    //  |__________________|
+
+    // The plane (zin) is calculated by dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zin       = min(DEPTH_GEMM3D - 1, zin);
+
+    // Add offset due to the cross plane paddings
+    zin *= (src_cross_plane_pad * src0_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply src0_stride_z by DEPTH_GEMM3D
+    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
+
+#else // defined(REINTERPRET_INPUT_AS_3D)
+
     // Add offset for batched GEMM
     src_addr.s0 += get_global_id(2) * src0_stride_z;
 
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
 #if defined(MATRIX_B_DEPTH)
     // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
     src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
@@ -2117,6 +2412,19 @@
     int i = 0;
     for(; i <= ((int)COLS_A - 4); i += 4)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2128,6 +2436,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
         src_addr.s1 += src1_stride_y;
@@ -2188,6 +2498,19 @@
 
     for(; i < (int)COLS_A; ++i)
     {
+#if defined(REINTERPRET_INPUT_AS_3D)
+        // Load values from matrix A
+        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#else  // defined(REINTERPRET_INPUT_AS_3D)
         // Load values from matrix A
         half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -2199,6 +2522,8 @@
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_INPUT_AS_3D)
+
         // Load values from matrix B
         half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
 
@@ -2260,7 +2585,7 @@
     zout       = min(DEPTH_GEMM3D - 1, zout);
 
     // Add offset due to the cross plane paddings
-    zout *= (cross_plane_pad * dst_stride_y);
+    zout *= (dst_cross_plane_pad * dst_stride_y);
 
     // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
     // multiply dst_stride_z by DEPTH_GEMM3D
diff --git a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
index 12a40cd..6ea1160 100644
--- a/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
+++ b/src/core/CL/kernels/CLGEMMInterleave4x4Kernel.cpp
@@ -23,6 +23,7 @@
  */
 #include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
 
+#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/CL/CLHelpers.h"
 #include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/CL/CLValidate.h"
@@ -30,6 +31,7 @@
 #include "arm_compute/core/CL/OpenCL.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Window.h"
@@ -40,7 +42,7 @@
 
 namespace
 {
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
 {
     ARM_COMPUTE_RETURN_ERROR_ON(mult_interleave4x4_height < 1);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input);
@@ -50,24 +52,30 @@
 
     if(output->total_size() != 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), compute_interleaved_shape(*input, mult_interleave4x4_height, reinterpret_input_as_3d));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
     }
 
     return Status{};
 }
 
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
 {
     constexpr unsigned int num_elems_processed_per_iteration_x = 4;
     constexpr unsigned int num_elems_processed_per_iteration_y = 4;
     const unsigned int     num_elems_written_per_iteration     = num_elems_processed_per_iteration_x * num_elems_processed_per_iteration_y * mult_interleave4x4_height;
     bool                   window_changed                      = false;
 
-    // Configure kernel window
-    Window                win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-    AccessWindowRectangle input_access(input, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
-    window_changed = window_changed || update_window_and_padding(win, input_access);
+    TensorInfo tmp_info(*input);
+
+    if(reinterpret_input_as_3d)
+    {
+        // Since the input tensor has to be reinterpreted as 3D and the execute window is based on a 2D interleave,
+        // the window needs to be constructed on the 2D collapsed version of the tensor
+        TensorShape tmp_shape(input->tensor_shape());
+        tmp_shape.collapse(2U, 1U);
+        tmp_info.set_tensor_shape(tmp_shape);
+    }
 
     // Output auto inizialitation if not yet initialized
     auto_init_if_empty(*output, input->clone()->set_tensor_shape(compute_interleaved_shape(*input, mult_interleave4x4_height)));
@@ -76,9 +84,22 @@
     const float scale_x = 4.0f * static_cast<float>(mult_interleave4x4_height);
     const float scale_y = 1.0f / (scale_x);
 
+    // Note: bottom paddings are calculated manually as the input can be reinterpreted as a 3D tensor
+    // The only way to set the paddings correctly is to set them explicitly through AccessWindowStatic
+    const int m          = reinterpret_input_as_3d ? input->tensor_shape()[1] * input->tensor_shape()[2] : input->tensor_shape()[1];
+    const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+    Window win    = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+    Window win_in = calculate_max_window(*input, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+
+    AccessWindowStatic input_access(input, 0, 0,
+                                    ceil_to_multiple(input->dimension(0), num_elems_processed_per_iteration_x),
+                                    input->dimension(1) + bottom_pad);
     AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration, 1, scale_x, scale_y);
-    window_changed = window_changed || update_window_and_padding(win, output_access);
-    output_access.set_valid_region(win, input->valid_region());
+
+    window_changed = update_window_and_padding(win_in, input_access) || // window used by the execute_window_loop
+                     update_window_and_padding(win, output_access);     // window used to update the padding requirements of output tensor
+    output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
 
     // Collapse along the Z direction
     // This collapse needs to be here in order to tune the Z dimension of LWS
@@ -90,26 +111,31 @@
 } // namespace
 
 CLGEMMInterleave4x4Kernel::CLGEMMInterleave4x4Kernel()
-    : _input(nullptr), _output(nullptr)
+    : _input(nullptr), _output(nullptr), _reinterpret_input_as_3d(false)
 {
 }
 
-void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height)
+void CLGEMMInterleave4x4Kernel::configure(const ICLTensor *input, ICLTensor *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
 
     // Output auto inizialitation if not yet initialized
-    auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height)));
+    auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_interleaved_shape(*input->info(), mult_interleave4x4_height, reinterpret_input_as_3d)));
 
     // Perform validate step
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d));
 
-    _input  = input;
-    _output = output;
+    _input                   = input;
+    _output                  = output;
+    _reinterpret_input_as_3d = reinterpret_input_as_3d;
 
     // Create build options
     CLBuildOptions build_opts;
     build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
+    build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+    build_opts.add_option_if(_reinterpret_input_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(1)));
+    build_opts.add_option_if(_reinterpret_input_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(input->info()->dimension(2)));
+
     switch(input->info()->element_size())
     {
         case 1:
@@ -129,12 +155,13 @@
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemm_interleave4x4", build_opts.options()));
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height);
+    auto win_config = validate_and_configure_window(input->info(), output->info(), mult_interleave4x4_height, reinterpret_input_as_3d);
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     ICLKernel::configure(win_config.second);
 
     // Set config_id for enabling LWS tuning
     _config_id = "interleave4x4_";
+    _config_id += (_reinterpret_input_as_3d ? "3d_" : "");
     _config_id += lower_string(string_from_data_type(input->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(0));
@@ -146,10 +173,10 @@
     _config_id += support::cpp11::to_string(output->info()->dimension(3));
 }
 
-Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height)
+Status CLGEMMInterleave4x4Kernel::validate(const ITensorInfo *input, const ITensorInfo *output, int mult_interleave4x4_height, bool reinterpret_input_as_3d)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height).first);
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, mult_interleave4x4_height, reinterpret_input_as_3d));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), output->clone().get(), mult_interleave4x4_height, reinterpret_input_as_3d).first);
 
     return Status{};
 }
@@ -170,6 +197,14 @@
      */
     Window slice = window.first_slice_window_3D();
 
+    if(_reinterpret_input_as_3d)
+    {
+        // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
+        const unsigned int idx0                  = 2 * num_arguments_per_3D_tensor();
+        const unsigned int total_cross_plane_pad = _input->info()->padding().top + _input->info()->padding().bottom;
+        _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+    }
+
     do
     {
         unsigned int idx = 0;
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 0c629af..c9e6bb3 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -56,6 +56,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_interleaved_transposed && reshape_info.reinterpret_input_as_3d(), "The input tensor cannot be reinterpreted as 3D if is_interleaved_transposed is true");
 
     if(!is_interleaved_transposed)
     {
@@ -125,6 +126,9 @@
 
     if(is_interleaved_transposed)
     {
+        // reinterpret_input_as_3d is not supported if is_interleaved_transposed is set
+        ARM_COMPUTE_ERROR_ON(reshape_info.reinterpret_input_as_3d());
+
         // Configure kernel window
         num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
         num_elems_processed_per_iteration_y = 4;
@@ -158,7 +162,7 @@
 
         // Note: bottom paddings are calculated manually as the output can be reinterpreted as 3D tensor
         // The only way to set properly the paddings, it is to set those explicitly through the AccessWindowStatic
-        const int m          = input0->tensor_shape()[1];
+        const int m          = reshape_info.reinterpret_input_as_3d() ? input0->tensor_shape()[1] * input0->tensor_shape()[2] : input0->tensor_shape()[1];
         const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
 
         // Create kernels according to the architecture, data type and input size.
@@ -172,7 +176,7 @@
         win     = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
         win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
 
-        AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
+        AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), input0->dimension(1) + bottom_pad);
         AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
         AccessWindowStatic output_access(output, 0, 0,
                                          ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
@@ -198,7 +202,7 @@
 } // namespace
 
 CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
-    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _is_gemm3d(false)
+    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _reinterpret_input_as_3d(false), _reinterpret_output_as_3d(false)
 {
 }
 
@@ -209,19 +213,22 @@
     // Perform validate step
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), is_interleaved_transposed, reshape_info));
 
-    _input0         = input0;
-    _input1         = input1;
-    _output         = output;
-    _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
+    _input0                   = input0;
+    _input1                   = input1;
+    _output                   = output;
+    _reinterpret_input_as_3d  = reshape_info.reinterpret_input_as_3d();
+    _reinterpret_output_as_3d = (reshape_info.depth_output_gemm3d() != 1);
+
+    // Check if we need to slide the matrix B
+    const unsigned int num_dimensions_input0 = _reinterpret_input_as_3d ? _input0->info()->num_dimensions() - 1 : _input0->info()->num_dimensions();
+
+    _slide_matrix_b = (_input1->info()->num_dimensions() >= num_dimensions_input0);
 
     const DataType data_type = input0->info()->data_type();
 
     // Get target architecture
     GPUTarget gpu_target = get_target();
 
-    // Check if the output has to be reinterpreted as 3D
-    _is_gemm3d = (reshape_info.depth_output_gemm3d() != 1) && is_data_type_float(data_type);
-
     ElementsProcessed num_elements_processed{};
 
     // Configure kernel window
@@ -237,9 +244,10 @@
     {
         build_opts.add_option("-DALPHA=" + float_to_string_with_full_precision(alpha));
     }
-    build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
-    build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
-    build_opts.add_option_if(_is_gemm3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
+    build_opts.add_option_if(_reinterpret_input_as_3d, "-DREINTERPRET_INPUT_AS_3D");
+    build_opts.add_option_if(_reinterpret_output_as_3d, "-DREINTERPRET_OUTPUT_AS_3D");
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+    build_opts.add_option_if(_reinterpret_input_as_3d || _reinterpret_output_as_3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
 
     // Do not slide matrix B if _slide_matrix_b = false
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -305,7 +313,8 @@
     // Set config_id for enabling LWS tuning
     _config_id = "gemm_";
     _config_id += (is_interleaved_transposed ? "reshaped_" : "");
-    _config_id += (_is_gemm3d ? "3d_" : "");
+    _config_id += (_reinterpret_input_as_3d ? "3di_" : "");
+    _config_id += (_reinterpret_output_as_3d ? "3do_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -355,10 +364,18 @@
     slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
     slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
-    if(_is_gemm3d)
+    if(_reinterpret_input_as_3d)
     {
         // Pass bottom paddings to the kernel if the input has to be reinterpreted as 3D tensor
         const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3;
+        const unsigned int total_cross_plane_pad = _input0->info()->padding().top + _input0->info()->padding().bottom;
+        _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
+    }
+
+    if(_reinterpret_output_as_3d)
+    {
+        // Pass the cross-plane paddings of the output to the kernel if the output has to be reinterpreted as a 3D tensor
+        const unsigned int idx0                  = 3 * num_arguments_per_2D_tensor() + 3 + (_reinterpret_input_as_3d ? 1 : 0);
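+        // Note: 3 * num_arguments_per_2D_tensor() + 3 skips the arguments of the three 2D tensors and their
+        // stride_z arguments; when the input is also reinterpreted as 3D, its cross-plane padding occupies that
+        // slot, so the output cross-plane padding is passed one argument further.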
         const unsigned int total_cross_plane_pad = _output->info()->padding().top + _output->info()->padding().bottom;
         _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(total_cross_plane_pad));
     }
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 1f4df4f..1d1b17b 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -102,7 +102,8 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    const int m                         = a->info()->dimension(1);
+    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int m                         = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
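+    // Illustrative example (assumed shapes): with a 1x1 NHWC convolution input of shape [IFM, W, H, batches]
+    // passed directly (no im2col), M becomes W * H instead of a->info()->dimension(1) alone.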
     const int n                         = b->info()->dimension(0);
     const int k                         = a->info()->dimension(0);
     const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
@@ -118,6 +119,12 @@
     // Check if we need to reshape the matrix A and matrix B
     _is_interleaved_transposed = is_interleaved_transposed(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run, gpu_target);
 
+    // If _is_interleaved_transposed is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleave4x4Kernel will be 2D
+    if(_is_interleaved_transposed)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
     if(_is_interleaved_transposed)
     {
         matrix_a = &_tmp_a;
@@ -132,14 +139,15 @@
         // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
 
         // Configure interleave kernel
-        _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height);
+        _interleave_kernel.configure(a, &_tmp_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d());
 
         // Configure transpose kernel
         _transpose_kernel.configure(b, &_tmp_b, mult_transpose1xW_width);
     }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
+    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d,
+                                                                                                        reinterpret_input_as_3d));
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
@@ -180,11 +188,13 @@
     // Arguments used by GEMMReshapeInfo
     // If we pass the matrix A and matrix B reshaped to CLGEMMMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to CLGEMMReshapeInfo
     // in order to know how the matrices have been reshaped
-    const int m                         = a->dimension(1);
+    bool      reinterpret_input_as_3d   = gemm_info.reinterpret_input_as_3d();
+    const int m                         = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
     const int n                         = b->dimension(0);
     const int k                         = a->dimension(0);
     int       mult_transpose1xW_width   = 1;
     int       mult_interleave4x4_height = 1;
+    const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
 
     if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
     {
@@ -192,19 +202,25 @@
         mult_interleave4x4_height = 2;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
-
     // Check if we need to reshape the matrix A and matrix B
     const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);
 
+    // If run_interleave_transpose is set, force reinterpret_input_as_3d to be false as the output of CLGEMMInterleave4x4Kernel will be 2D
+    if(run_interleave_transpose)
+    {
+        reinterpret_input_as_3d = false;
+    }
+
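+    // reshape_info is created only after run_interleave_transpose is known, so that reinterpret_input_as_3d
+    // matches the flag that configure() passes to the matrix multiply kernel.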
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, reinterpret_input_as_3d);
+
     if(run_interleave_transpose)
     {
         matrix_a_info = &tmp_a_info;
         matrix_b_info = &tmp_b_info;
 
         // Validate interleave kernel
-        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height)));
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height));
+        auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &tmp_a_info, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
 
         // Validate transpose kernel
         auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index f1d2924..de62829 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -91,15 +91,15 @@
 
 CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _col2im_kernel(), _activationlayer_function(),
-      _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _skip_im2col(false), _is_quantized(false),
-      _is_activationlayer_enabled(false), _is_prepared(false)
+      _add_bias_kernel(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
+      _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
 {
 }
 
 void CLGEMMConvolutionLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, int gemm_3d_depth)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info()));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), output->info(), gemm_3d_depth, _skip_im2col));
 
     if(_is_quantized)
     {
@@ -120,15 +120,16 @@
     else
     {
         // Configure matrix multiply function
-        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth));
+        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, gemm_3d_depth,
+                                                                                 _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */));
     }
 }
 
-Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth)
+Status CLGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth, bool skip_im2col)
 {
     const bool is_quantized = is_data_type_quantized_asymmetric(input->data_type());
 
-    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth);
+    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
     if(is_quantized)
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -180,7 +181,8 @@
     _original_weights = weights;
     _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
     _data_layout      = data_layout;
-    _skip_im2col      = false;
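+    // im2col can be skipped for 1x1 kernels in NHWC, since each output pixel already reads a contiguous
+    // vector of input channels; the GEMM input is then reinterpreted as a 3D tensor instead.
+    // Quantized types keep using im2col/col2im and are therefore excluded here.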
+    _skip_im2col      = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !_is_quantized;
+    _append_bias      = (biases != nullptr) && (!_is_quantized);
 
     // Set the GPU target for im2col and col2im
     _im2col_kernel.set_target(CLScheduler::get().target());
@@ -191,9 +193,8 @@
     ICLTensor       *gemm_output_to_use        = output;
     ICLTensor       *gemm_output_staged_to_use = output;
 
-    const bool       append_bias   = (biases != nullptr) && (!_is_quantized);
-    const unsigned   bias_element  = (append_bias) ? 1 : 0;
-    const ICLTensor *biases_to_use = (append_bias) ? biases : nullptr;
+    const unsigned   bias_element  = (_append_bias && !_skip_im2col) ? 1 : 0;
+    const ICLTensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;
 
     // Get parameters from conv_info
     unsigned int stride_x = 0;
@@ -238,12 +239,17 @@
         _memory_group.manage(&_im2col_output);
 
         // Configure and tune im2col
-        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation);
+        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);
         CLScheduler::get().tune_kernel_static(_im2col_kernel);
 
         // Update GEMM input
         gemm_input_to_use = &_im2col_output;
     }
+    else if(_append_bias)
+    {
+        // Configure add bias kernel
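+        // With im2col skipped the bias can no longer be appended during im2col, so it is added to the
+        // convolution output with a separate element-wise addition instead.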
+        _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
+    }
 
     // Create GEMM output tensor
     if(!is_nhwc || _is_quantized)
@@ -281,28 +287,23 @@
         float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
-        if(!is_nhwc)
-        {
-            _memory_group.manage(&_tmp_output);
-            gemm_output_staged_to_use = &_tmp_output;
-        }
+
+        _memory_group.manage(&_tmp_output);
+        gemm_output_staged_to_use = &_tmp_output;
+
         _gemmlowp_output_stage.configure(gemm_output_to_use, biases, gemm_output_staged_to_use, output_multiplier, output_shift, output_quant_info.offset);
     }
 
-    if(!is_nhwc)
+    if(!is_nhwc || _is_quantized)
     {
         // Configure and tune Col2Im
         _col2im_kernel.configure(_is_quantized ? gemm_output_staged_to_use : gemm_output_to_use, output, std::make_pair(conv_w, conv_h));
         CLScheduler::get().tune_kernel_static(_col2im_kernel);
     }
 
-    if(_is_quantized && !is_nhwc)
-    {
-        _tmp_output.allocator()->allocate();
-    }
-
     if(!is_nhwc || _is_quantized)
     {
+        _tmp_output.allocator()->allocate();
         _gemm_output.allocator()->allocate();
     }
 
@@ -348,10 +349,10 @@
     const ITensorInfo *weights_to_use            = weights;
 
     const bool     is_nhwc      = data_layout == DataLayout::NHWC;
-    const bool     skip_im2col  = false;
     const bool     is_quantized = is_data_type_quantized_asymmetric(data_type);
+    const bool     skip_im2col  = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1) && !is_quantized;
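+    // Keep this condition in sync with CLGEMMConvolutionLayer::configure(), otherwise validate() and the
+    // configured function would disagree on whether im2col is skipped.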
     const bool     append_bias  = (biases != nullptr) && (!is_quantized);
-    const unsigned bias_element = (append_bias) ? 1 : 0;
+    const unsigned bias_element = (append_bias && !skip_im2col) ? 1 : 0;
 
     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -410,6 +411,11 @@
         ARM_COMPUTE_RETURN_ON_ERROR(CLIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
         gemm_input_to_use = &im2col_reshaped_info;
     }
+    else if(append_bias)
+    {
+        // Validate add bias kernel
+        ARM_COMPUTE_RETURN_ON_ERROR(CLArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+    }
 
     // Create GEMM output tensor
     if(!is_nhwc || is_quantized)
@@ -424,25 +430,24 @@
         gemm_output_to_use = &info_gemm;
     }
 
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, gemm_output_to_use, (data_layout == DataLayout::NHWC) ? conv_h : 1, skip_im2col));
 
     if(is_quantized)
     {
         float multiplier = input->quantization_info().scale * weights_to_use->quantization_info().scale / output->quantization_info().scale;
         int   output_multiplier, output_shift;
         quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
-        if(!is_nhwc)
-        {
-            tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
-            tmp_info.set_quantization_info(output->quantization_info());
-            gemm_output_staged_to_use = &tmp_info;
-        }
+
+        tmp_info = TensorInfo(gemm_output_to_use->tensor_shape(), 1, DataType::QASYMM8);
+        tmp_info.set_quantization_info(output->quantization_info());
+        gemm_output_staged_to_use = &tmp_info;
+
         // Validate output stage for quantized case
         CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::validate(gemm_output_to_use, biases, gemm_output_staged_to_use, output->quantization_info().offset);
     }
 
     // Validate Col2Im
-    if(!is_nhwc)
+    if(!is_nhwc || is_quantized)
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CLCol2ImKernel::validate(is_quantized ? gemm_output_staged_to_use : gemm_output_to_use,
                                                              output,
@@ -485,8 +490,13 @@
         _mm_gemm.run();
     }
 
+    if(_skip_im2col && _append_bias)
+    {
+        CLScheduler::get().enqueue(_add_bias_kernel);
+    }
+
     // Reshape output matrix
-    if(_data_layout == DataLayout::NCHW)
+    if(_data_layout == DataLayout::NCHW || _is_quantized)
     {
         CLScheduler::get().enqueue(_col2im_kernel, false);
     }
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 842ee73..c2e18a7 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -205,16 +205,17 @@
     const int             k                         = a->dimension(0);
     constexpr int         mult_transpose1xW_width   = 1;
     constexpr int         mult_interleave4x4_height = 1;
-    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+    const int             depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
+    const GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d);
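+    // Propagate the requested 3D output depth to the matrix multiply kernel validation.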
 
     bool reshape_matrices = is_interleaved_transposed(m, n, k, gemm_info.reshape_b_only_on_first_run(), CLScheduler::get().target());
 
     if(reshape_matrices)
     {
-        TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height), 1, a->data_type());
+        TensorInfo info_a(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()), 1, a->data_type());
         TensorInfo info_b(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width), 1, b->data_type());
 
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMInterleave4x4Kernel::validate(a, &info_a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d()));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMTranspose1xWKernel::validate(b, &info_b, mult_transpose1xW_width));
         ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyKernel::validate(&info_a, &info_b, output, reshape_matrices, reshape_info));
     }