COMPMID-911: Allow GEMM to work with 3D tensors

Change-Id: I8c4823a0d909e19e9ef548f00b9ae98c66de61dd
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123569
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index ad38c7e..2368125 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -165,6 +165,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -183,13 +189,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows of dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
-                                                 uint dst_stride_z)
+                                                 uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                 ,
+                                                 uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                )
 {
     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -273,6 +288,40 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store 4x4 block
+    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
 
@@ -281,6 +330,7 @@
     vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
     vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
     vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 /** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -293,6 +343,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -311,13 +367,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows of dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
-                                                         uint dst_stride_z)
+                                                         uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                         ,
+                                                         uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                        )
 {
     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -533,6 +598,40 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store 4x4 block
+    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
 
@@ -541,6 +640,7 @@
     vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
     vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
     vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 // Undefine local defines
@@ -556,6 +656,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -574,13 +680,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows of dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
-                                                 uint dst_stride_z)
+                                                 uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                 ,
+                                                 uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                )
 {
     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -664,6 +779,40 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store 4x8 block
+    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
     // Add offset for batched GEMM
     dst_addr += z * dst_stride_z;
 
@@ -672,6 +821,7 @@
     vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
     vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
     vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 /** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
@@ -683,6 +833,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -701,13 +857,19 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows of dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
-                                                         uint dst_stride_z)
+                                                         uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                         ,
+                                                         uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                        )
 {
     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -876,11 +1038,47 @@
-    // Add offset for batched GEMM
-    dst_addr += z * dst_stride_z;
 
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store 4x8 block
+    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+
+#else  // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
     // Store 4x8 block
     vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
     vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
     vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
     vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 // Undefine local defines
@@ -917,6 +1117,9 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  */
 __kernel void gemm_mm_interleaved_transposed_qs8(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
@@ -1039,6 +1242,9 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  */
 __kernel void gemm_mm_interleaved_transposed_qs16(IMAGE_DECLARATION(src0),
                                                   IMAGE_DECLARATION(src1),
@@ -1138,6 +1344,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1156,13 +1368,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding of each output plane, in rows of dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                      IMAGE_DECLARATION(src1),
                                      IMAGE_DECLARATION(dst),
                                      uint src0_stride_z,
                                      uint src1_stride_z,
-                                     uint dst_stride_z)
+                                     uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                     ,
+                                     uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                    )
 {
     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
 
@@ -1271,36 +1492,85 @@
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-    // Add offset for batched GEMM
-    dst_addr += get_global_id(2) * dst_stride_z;
-
     // Multiply by the weight of matrix-matrix product and store the result
 #if defined(ALPHA)
     acc0 = acc0 * (VECTOR_TYPE)ALPHA;
 #endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+    int z = get_global_id(2);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store output block
+    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
+    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store output block
     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
     (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
-    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
     (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
-    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
     (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
-    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
-#endif // defined(ALPHA)
     VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
     (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 #endif // defined(DATA_TYPE)
 
@@ -1314,6 +1584,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1332,13 +1608,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the first source matrix (src0) in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the second source matrix (src1) in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding in units of elements (only if REINTERPRET_OUTPUT_AS_3D is defined)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
-                                                 uint dst_stride_z)
+                                                 uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                 ,
+                                                 uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                )
 {
     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
 
@@ -1585,6 +1870,8 @@
         src_addr.s0 += sizeof(float);
     }
 
+    int z = get_global_id(2);
+
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
@@ -1595,46 +1882,83 @@
     acc02 = acc02 * ALPHA;
     acc03 = acc03 * ALPHA;
 #endif // defined(ALPHA)
-
-    // Compute dst address
-    __global uchar *dst_addr = offset(&dst, 0, 0);
-
-    // Add offset for batched GEMM
-    dst_addr += get_global_id(2) * dst_stride_z;
-
-    float4 acc0 = ((float4)(acc00, acc01, acc02, acc03));
-    vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
-
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
     acc10 = acc10 * ALPHA;
     acc11 = acc11 * ALPHA;
     acc12 = acc12 * ALPHA;
     acc13 = acc13 * ALPHA;
-#endif // defined(ALPHA)
-    float4 acc1 = ((float4)(acc10, acc11, acc12, acc13));
-    vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
     acc20 = acc20 * ALPHA;
     acc21 = acc21 * ALPHA;
     acc22 = acc22 * ALPHA;
     acc23 = acc23 * ALPHA;
-#endif // defined(ALPHA)
-    float4 acc2 = ((float4)(acc20, acc21, acc22, acc23));
-    vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
-#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
     acc30 = acc30 * ALPHA;
     acc31 = acc31 * ALPHA;
     acc32 = acc32 * ALPHA;
     acc33 = acc33 * ALPHA;
-#endif // defined(ALPHA)
-    float4 acc3 = ((float4)(acc30, acc31, acc32, acc33));
-    vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+    // Compute dst address
+    __global uchar *dst_addr = offset(&dst, 0, 0);
+
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
+
+    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store the output block
+    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store the output block
+    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 /** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
@@ -1648,6 +1972,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1666,13 +1996,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding in number of rows, i.e. multiples of dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
-                                                      uint dst_stride_z)
+                                                      uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                      ,
+                                                      uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                     )
 {
     // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
@@ -1857,46 +2196,87 @@
         src_addr.s0 += sizeof(float);
     }
 
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    acc00 = acc00 * ALPHA;
+    acc01 = acc01 * ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+    acc10 = acc10 * ALPHA;
+    acc11 = acc11 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+    acc20 = acc20 * ALPHA;
+    acc21 = acc21 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+    acc30 = acc30 * ALPHA;
+    acc31 = acc31 * ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+    int z = get_global_id(2);
+
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-    // Add offset for batched GEMM
-    dst_addr += get_global_id(2) * dst_stride_z;
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc00 = acc00 * ALPHA;
-    acc01 = acc01 * ALPHA;
-#endif // defined(ALPHA)
-    float2 acc0 = ((float2)(acc00, acc01));
-    vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store the output block
+    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
-    acc10 = acc10 * ALPHA;
-    acc11 = acc11 * ALPHA;
-#endif // defined(ALPHA)
-    float2 acc1 = ((float2)(acc10, acc11));
-    vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
-    acc20 = acc20 * ALPHA;
-    acc21 = acc21 * ALPHA;
-#endif // defined(ALPHA)
-    float2 acc2 = ((float2)(acc20, acc21));
-    vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
-    acc30 = acc30 * ALPHA;
-    acc31 = acc31 * ALPHA;
-#endif // defined(ALPHA)
-    float2 acc3 = (float2)(acc30, acc31);
-    vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store the output block
+    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // defined(REINTERPRET_OUTPUT_AS_3D)
 }
 
 #if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
@@ -1910,6 +2290,12 @@
  * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
  *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
  *
+ * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
+ *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
+ *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
+ *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
+ *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped
+ *
  * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
  * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
  * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
@@ -1928,13 +2314,22 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  pad_bottom                         Bottom padding in number of rows, i.e. multiples of dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
  */
 __kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                  IMAGE_DECLARATION(src1),
                                                  IMAGE_DECLARATION(dst),
                                                  uint src0_stride_z,
                                                  uint src1_stride_z,
-                                                 uint dst_stride_z)
+                                                 uint dst_stride_z
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+                                                 ,
+                                                 uint pad_bottom
+#endif // REINTERPRET_OUTPUT_AS_3D
+                                                )
 {
     int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
 
@@ -2071,38 +2466,83 @@
 #endif                                   // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
     }
 
+    // Multiply by the weight of matrix-matrix product and store the result
+#if defined(ALPHA)
+    acc0 = acc0 * (half8)ALPHA;
+#endif // defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+    acc1 = acc1 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+    acc2 = acc2 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+    acc3 = acc3 * (half8)ALPHA;
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
+
+    int z = get_global_id(2);
+
     // Compute destination address
     Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
 
     // Compute dst address
     __global uchar *dst_addr = offset(&dst, 0, 0);
 
-    // Add offset for batched GEMM
-    dst_addr += get_global_id(2) * dst_stride_z;
+#if defined(REINTERPRET_OUTPUT_AS_3D)
+    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
+    // in order to take into account the presence of possible bottom paddings
+    //
+    //  |             |
+    //  |    plane0   |
+    //  |             |
+    //  |_____________|
+    //  |*************|
+    //  |  pad_bottom |
+    //  |*************|
+    //  |             |
+    //  |    plane1   |
+    //  |             |
+    //  |_____________|
 
-    // Multiply by the weight of matrix-matrix product and store the result
-#if defined(ALPHA)
-    acc0 = acc0 * (half8)ALPHA;
-#endif // defined(ALPHA)
+    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
+    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
+    zout       = min(DEPTH_GEMM3D - 1, zout);
+
+    // Add offset due to the bottom paddings
+    zout *= (pad_bottom * dst_stride_y);
+
+    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
+    // multiply dst_stride_z by DEPTH_GEMM3D
+    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
+
+    // Store the output block
+    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+#else // defined(REINTERPRET_OUTPUT_AS_3D)
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store the output block
     vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-#if defined(ALPHA)
-    acc1 = acc1 * (half8)ALPHA;
-#endif // defined(ALPHA)
     vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-#if defined(ALPHA)
-    acc2 = acc2 * (half8)ALPHA;
-#endif // defined(ALPHA)
     vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-#if defined(ALPHA)
-    acc3 = acc3 * (half8)ALPHA;
-#endif // defined(ALPHA)
     vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+#endif // REINTERPRET_OUTPUT_AS_3D
 }
 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 
@@ -2135,6 +2575,9 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  */
 __kernel void gemm_mm_qs8(IMAGE_DECLARATION(src0),
                           IMAGE_DECLARATION(src1),
@@ -2319,6 +2762,9 @@
  * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
+ * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
  */
 __kernel void gemm_mm_qs16(IMAGE_DECLARATION(src0),
                            IMAGE_DECLARATION(src1),
@@ -2471,20 +2917,24 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
  */
-__kernel void gemm_ma_f32(IMAGE_DECLARATION(src),
-                          IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
+                          TENSOR3D_DECLARATION(dst))
 {
     // Compute source and destination addresses
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
     // Load values from A x B
     float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
@@ -2509,20 +2959,24 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
  */
-__kernel void gemm_ma_f16(IMAGE_DECLARATION(src),
-                          IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
+                          TENSOR3D_DECLARATION(dst))
 {
     // Compute source and destination addresses
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
     // Load values from A x B
     half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
@@ -2550,20 +3004,24 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
  */
-__kernel void gemm_ma_qs8(IMAGE_DECLARATION(src),
-                          IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs8(TENSOR3D_DECLARATION(src),
+                          TENSOR3D_DECLARATION(dst))
 {
     // Compute source and destination addresses
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
     // Load values from A x B
     char16 alpha_ab = vload16(0, (__global char *)dst.ptr);
@@ -2589,20 +3047,24 @@
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
  * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
  * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
  * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
  * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
  * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
  */
-__kernel void gemm_ma_qs16(IMAGE_DECLARATION(src),
-                           IMAGE_DECLARATION(dst))
+__kernel void gemm_ma_qs16(TENSOR3D_DECLARATION(src),
+                           TENSOR3D_DECLARATION(dst))
 {
     // Compute source and destination addresses
-    Image src = CONVERT_TO_IMAGE_STRUCT(src);
-    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
     // Load values from A x B
     short8 alpha_ab = vload8(0, (__global short *)dst.ptr);
diff --git a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
index e6a1baf..c50ee24 100644
--- a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp
@@ -126,14 +126,14 @@
     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
 
-    Window slice = window.first_slice_window_2D();
+    Window slice = window.first_slice_window_3D();
 
     do
     {
         unsigned int idx = 0;
-        add_2D_tensor_argument(idx, _input, slice);
-        add_2D_tensor_argument(idx, _output, slice);
+        add_3D_tensor_argument(idx, _input, slice);
+        add_3D_tensor_argument(idx, _output, slice);
         enqueue(queue, *this, slice);
     }
-    while(window.slide_window_slice_2D(slice));
+    while(window.slide_window_slice_3D(slice));
 }
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index fc52f4e..2c2a92d 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -56,19 +56,13 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_fixed_point(input0->data_type()) && (reshape_info.depth_output_gemm3d() != 1), "GEMM3D only supports floating point data types");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the matrix B must be <= 3");
 
     if(!is_interleaved_transposed)
     {
         ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
-
-        if(output->total_size() != 0)
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
-            ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
-        }
     }
     else
     {
@@ -94,14 +88,14 @@
 
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+    }
 
-        if(output->total_size() != 0)
-        {
-            ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
-            ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
-            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
-        }
+    if(output->total_size() != 0)
+    {
+        const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info));
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
     }
 
     return Status{};
@@ -113,6 +107,7 @@
 {
     bool   window_changed = false;
     Window win{};
+    Window win_out{};
 
     const DataType data_type                           = input0->data_type();
     unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0];
@@ -121,23 +116,43 @@
     // Output tensor auto inizialitation if not yet initialized
     auto_init_if_empty(*output, input0->clone()->set_tensor_shape(compute_mm_shape(*input0, *input1, is_interleaved_transposed, reshape_info)));
 
+    TensorInfo tmp_info(*output);
+
+    if(reshape_info.depth_output_gemm3d() != 1)
+    {
+        // Since the output tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM,
+        // the window needs to be constructed on the 2D collapsed version of the tensor
+        TensorShape tmp_shape(output->tensor_shape());
+        tmp_shape.collapse(2U, 1U);
+        tmp_info.set_tensor_shape(tmp_shape);
+    }
+
     if(is_interleaved_transposed)
     {
         // Configure kernel window
         num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
         num_elems_processed_per_iteration_y = 4;
 
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        // Note: the bottom padding is calculated manually because the output can be reinterpreted as a 3D tensor.
+        // The only way to set the padding properly is to specify it explicitly through an AccessWindowStatic
+        const int m          = reshape_info.m();
+        const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
+        win     = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
 
         AccessWindowRectangle input0_access(input0, 0, 0, num_elems_processed_per_iteration_y, 1, 1.f, 0.25f);
         AccessWindowStatic    input1_access(input1, 0, 0,
                                             ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x),
                                             ceil_to_multiple(input1->dimension(1), num_elems_processed_per_iteration_y));
-        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+        AccessWindowStatic output_access(output, 0, 0,
+                                         ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+                                         output->dimension(1) + bottom_pad);
 
-        window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+        window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
 
-        output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
+        output_access.set_valid_region(win_out, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
     }
     else // The input tensors have not been reshaped
     {
@@ -145,6 +160,11 @@
         num_elems_processed_per_iteration_x = max_cl_vector_width / data_size_from_type(data_type);
         num_elems_processed_per_iteration_y = std::min(static_cast<int>(output->dimension(1)), 4);
 
+        // Note: the bottom padding is calculated manually because the output can be reinterpreted as a 3D tensor.
+        // The only way to set the padding properly is to specify it explicitly through an AccessWindowStatic
+        const int m          = input0->tensor_shape()[1];
+        const int bottom_pad = (num_elems_processed_per_iteration_y - (m % num_elems_processed_per_iteration_y)) % num_elems_processed_per_iteration_y;
+
         // Create kernels according to the architecture, data type and input size.
         GPUTarget arch_target = get_arch_from_target(gpu_target);
         if(arch_target == GPUTarget::BIFROST && data_type == DataType::F32)
@@ -153,17 +173,21 @@
         }
 
         // Configure window
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        win     = calculate_max_window(tmp_info, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        win_out = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
 
-        AccessWindowStatic    input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
-        AccessWindowStatic    input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
-        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
+        AccessWindowStatic input0_access(input0, 0, 0, input0->dimension(0), ceil_to_multiple(input0->dimension(1), num_elems_processed_per_iteration_y));
+        AccessWindowStatic input1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
+        AccessWindowStatic output_access(output, 0, 0,
+                                         ceil_to_multiple(output->dimension(0), num_elems_processed_per_iteration_x),
+                                         output->dimension(1) + bottom_pad);
 
-        window_changed = update_window_and_padding(win, input0_access, input1_access, output_access);
+        window_changed = update_window_and_padding(win, input0_access, input1_access) || // window used by the execute_window_loop
+                         update_window_and_padding(win_out, output_access);              // window used to update the padding requirements of output tensor
 
         Coordinates coord;
         coord.set_num_dimensions(output->num_dimensions());
-        output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
+        output_access.set_valid_region(win_out, ValidRegion(coord, output->tensor_shape()));
     }
 
     // Collapse along the Z direction
@@ -178,7 +202,7 @@
 } // namespace
 
 CLGEMMMatrixMultiplyKernel::CLGEMMMatrixMultiplyKernel()
-    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true)
+    : _input0(nullptr), _input1(nullptr), _output(nullptr), _slide_matrix_b(true), _is_gemm3d(false)
 {
 }
 
@@ -194,9 +218,14 @@
     _output         = output;
     _slide_matrix_b = _input1->info()->num_dimensions() >= _input0->info()->num_dimensions();
 
-    const DataType  data_type  = input0->info()->data_type();
-    const int       fp_pos     = input0->info()->fixed_point_position();
-    const GPUTarget gpu_target = get_target();
+    const DataType data_type = input0->info()->data_type();
+    const int      fp_pos    = input0->info()->fixed_point_position();
+
+    // Get target architecture
+    GPUTarget gpu_target = get_target();
+
+    // Check if the output has to be reinterpreted as 3D
+    _is_gemm3d = (reshape_info.depth_output_gemm3d() != 1) && is_data_type_float(data_type);
 
     ElementsProcessed num_elements_processed{};
 
@@ -216,6 +245,9 @@
                                       "-DALPHA=" + support::cpp11::to_string((data_type == DataType::QS8 ? sqcvt_qs8_f32(alpha, fp_pos) : sqcvt_qs16_f32(alpha, fp_pos))),
                                       "-DALPHA=" + float_to_string_with_full_precision(alpha));
     }
+    build_opts.add_option_if(_is_gemm3d, "-DREINTERPRET_OUTPUT_AS_3D");
+    build_opts.add_option_if(_is_gemm3d, "-DHEIGHT_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(1)));
+    build_opts.add_option_if(_is_gemm3d, "-DDEPTH_GEMM3D=" + support::cpp11::to_string(output->info()->dimension(2)));
 
     // Do not slide matrix B if _slide_matrix_b = false
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
@@ -285,6 +317,7 @@
     // Set config_id for enabling LWS tuning
     _config_id = "gemm_";
     _config_id += (is_interleaved_transposed ? "reshaped_" : "");
+    _config_id += (_is_gemm3d ? "3d_" : "");
     _config_id += lower_string(string_from_data_type(input0->info()->data_type()));
     _config_id += "_";
     _config_id += support::cpp11::to_string(output->info()->dimension(1));
@@ -334,6 +367,13 @@
     slice_matrix_b.set(Window::DimX, Window::Dimension(0, 1, 1));
     slice_matrix_b.set(Window::DimY, Window::Dimension(0, 1, 1));
 
+    if(_is_gemm3d)
+    {
+        // Pass the bottom padding to the kernel when the output has to be reinterpreted as a 3D tensor
+        const unsigned int idx0 = 3 * num_arguments_per_2D_tensor() + 3;
+        _kernel.setArg<cl_uint>(idx0, static_cast<unsigned int>(_output->info()->padding().bottom));
+    }
+
     do
     {
         Window slice_b = slice;
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index a0ec66f..f9713bb 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -24,10 +24,6 @@
 #include "arm_compute/runtime/CL/functions/CLGEMM.h"
 
 #include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/GPUTarget.h"
 #include "arm_compute/core/Helpers.h"
@@ -111,6 +107,7 @@
     const int m                         = a->info()->dimension(1);
     const int n                         = b->info()->dimension(0);
     const int k                         = a->info()->dimension(0);
+    const int depth_output_gemm3d       = gemm_info.depth_output_gemm3d();
     int       mult_transpose1xW_width   = 1;
     int       mult_interleave4x4_height = 1;
 
@@ -144,7 +141,7 @@
     }
 
     // Configure and tune matrix multiply kernel
-    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height));
+    _mm_kernel.configure(matrix_a, matrix_b, output, alpha, _is_interleaved_transposed, GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d));
     CLScheduler::get().tune_kernel_static(_mm_kernel);
 
     if(_is_interleaved_transposed)
@@ -197,7 +194,7 @@
         mult_interleave4x4_height = 2;
     }
 
-    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height);
+    const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());
 
     // Check if we need to reshape the matrix A and matrix B
     const bool run_interleave_transpose = is_interleaved_transposed(m, n, k, a->data_type(), reshape_b_only_on_first_run, gpu_target);