COMPMID-1026 - Add support for 4x4 output tile in CLWinogradConvolutionLayer

The performance achieved can be found at the following confluence page:
https://confluence.arm.com/display/MLENG/GEMM-based+convolution+vs+Winograd-based+convolution+on+OpenCL

Change-Id: I4b690cfdd4eb4ff0cd17b14fdd49ccaa1d1dc85c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127729
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 50f623f..59be956 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -363,7 +363,9 @@
     { "winograd_input_transform_4x4_5x5_stepz1_nchw", "winograd.cl" },
     { "winograd_input_transform_2x2_3x3_stepz1_nchw", "winograd.cl" },
     { "winograd_input_transform_2x2_3x3_stepz2_nchw", "winograd.cl" },
+    { "winograd_input_transform_4x4_3x3_stepz1_nchw", "winograd.cl" },
     { "winograd_output_transform_2x2_3x3_nchw", "winograd.cl" },
+    { "winograd_output_transform_4x4_3x3_nchw", "winograd.cl" },
     { "winograd_output_transform_4x4_5x5_nchw", "winograd.cl" },
     { "YUYV422_to_IYUV_bt709", "color_convert.cl" },
     { "YUYV422_to_NV12_bt709", "color_convert.cl" },
diff --git a/src/core/CL/cl_kernels/winograd.cl b/src/core/CL/cl_kernels/winograd.cl
index cda23b0..f40a969 100644
--- a/src/core/CL/cl_kernels/winograd.cl
+++ b/src/core/CL/cl_kernels/winograd.cl
@@ -708,6 +708,265 @@
     vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
 }
 
+/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data format is NCHW
+ *
+ * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
+ * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
+ *
+ * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
+ * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
+ * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
+ * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
+ * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[out] dst_ptr                          Pointer to the destination tensor. Supported data types: as @p src_ptr
+ * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
+ * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst))
+{
+    int x = get_global_id(0);
+    int y = get_global_id(1);
+    int z = get_global_id(2);
+
+    // Compute input address
+    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
+
+    src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
+
+    // Row4
+    float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
+    float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
+
+    float k0 = d41.s0;
+    float k1 = d41.s0;
+    float k2 = d41.s0;
+    float k3 = d41.s0;
+    float k4 = d41.s0;
+    float k5 = 0.0f;
+
+    k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
+    k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
+    k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
+    k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
+    k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
+    k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
+
+    // Row0
+    float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
+    float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
+
+    // Row2
+    float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
+    float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
+
+    // Compute destination address
+    __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y);
+
+    uint dst_plane_stride = dst_stride_z / sizeof(float);
+
+    float out0  = k0;
+    float out1  = k1;
+    float out2  = k2;
+    float out3  = k3;
+    float out4  = k4;
+    float out5  = k5;
+    float out6  = k0;
+    float out7  = k1;
+    float out8  = k2;
+    float out9  = k3;
+    float out10 = k4;
+    float out11 = k5;
+    float out12 = k0;
+    float out13 = k1;
+    float out14 = k2;
+    float out15 = k3;
+    float out16 = k4;
+    float out17 = k5;
+    float out18 = k0;
+    float out19 = k1;
+    float out20 = k2;
+    float out21 = k3;
+    float out22 = k4;
+    float out23 = k5;
+    float out24 = k0;
+    float out25 = k1;
+    float out26 = k2;
+    float out27 = k3;
+    float out28 = k4;
+    float out29 = k5;
+
+    // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
+    out0 += 16.0f * d00.s0 - 20.0f * d00.s2 - 20.0f * d20.s0 + 25.0f * d20.s2 + 4.0f * d01.s0 - 5.0f * d21.s0;
+    out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+    out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 - 20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+    out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+    out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 - 10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
+    out5 += 16.0f * d00.s1 - 20.0f * d00.s3 - 20.0f * d20.s1 + 4.0f * d01.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
+
+    *(dst_addr) = out0;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out1;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out2;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out3;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out4;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out5;
+    dst_addr += dst_plane_stride;
+
+    // Row1
+    float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
+    float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
+
+    // Row3
+    float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
+    float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
+
+    // Compute common parts for the channels between [6, 29]
+    // Channels [6, 11]:  [out10, out11, out12, out13, out14, out15]
+    // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
+    float part0  = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
+    float part1  = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
+    float part2  = 16.0f * d20.s2 - 4.0f * d21.s0;
+    float part3  = 16.0f * d20.s1 - 4.0f * d20.s3;
+    float part4  = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
+    float part5  = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
+    float part6  = 4.0f * d20.s2 - 4.0f * d21.s0;
+    float part7  = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
+    float part8  = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
+    float part9  = 8.0f * d20.s1 - 8.0f * d20.s3;
+    float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
+    float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
+
+    // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
+    // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
+    float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
+    float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
+    float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
+    float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
+    float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
+    float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
+    float part18 = part6 * 0.25f; // d20.s2 - d21.s0
+    float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
+    float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
+    float part21 = part9 * 0.25f;                                                 // 2.0f * (d20.s1 - d20.s3)
+    float part22 = part10 * 0.25f;                                                // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
+    float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
+
+    out6 += part0 - part1;
+    out12 += part0 + part1;
+    out7 += part2 + part3 + part4 + part5;
+    out8 += part2 - part3 + part4 - part5;
+    out13 += part2 + part3 - part4 - part5;
+    out14 += part2 - part3 - part4 + part5;
+    out9 += part6 + part7 + part8 + part9;
+    out10 += part6 - part7 + part8 - part9;
+    out15 += part6 - part7 - part8 + part9;
+    out16 += part6 + part7 - part8 - part9;
+    out11 += part10 + part11;
+    out17 += part10 - part11;
+
+    out18 += part13 - part12;
+    out24 += part13 + part12;
+    out19 += part14 + part15 + part16 + part17;
+    out20 += part14 - part15 + part16 - part17;
+    out25 += part14 - part15 - part16 + part17;
+    out26 += part14 + part15 - part16 - part17;
+    out21 += part18 + part19 + part20 + part21;
+    out22 += part18 - part19 + part20 - part21;
+    out27 += part18 - part19 - part20 + part21;
+    out28 += part18 + part19 - part20 - part21;
+    out23 += part22 + part23;
+    out29 += part22 - part23;
+
+    *(dst_addr) = out6;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out7;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out8;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out9;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out10;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out11;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out12;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out13;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out14;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out15;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out16;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out17;
+    dst_addr += dst_plane_stride;
+
+    *(dst_addr) = out18;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out19;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out20;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out21;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out22;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out23;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out24;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out25;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out26;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out27;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out28;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out29;
+    dst_addr += dst_plane_stride;
+
+    // Row5
+    float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
+    float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
+
+    // Channels [30, 35]
+    out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+    out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+    out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+    out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+    out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
+    out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
+
+    *(dst_addr) = out0;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out1;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out2;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out3;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out4;
+    dst_addr += dst_plane_stride;
+    *(dst_addr) = out5;
+    dst_addr += dst_plane_stride;
+}
+
 #define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                     \
     ({                                                              \
         comm_fact.s0 = tmp.s2 - 4.25f * tmp.s4 + tmp.s6;            \
@@ -981,6 +1240,183 @@
     vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
 }
 
+/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data format is NCHW
+ *
+ * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
+ *
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
+ * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
+ * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
+ * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
+ * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
+ * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
+ * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
+ * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ */
+__kernel void winograd_output_transform_4x4_3x3_nchw(
+    TENSOR3D_DECLARATION(src),
+    TENSOR3D_DECLARATION(dst)
+#if defined(HAS_BIAS)
+    ,
+    VECTOR_DECLARATION(bias)
+#endif // defined(HAS_BIAS)
+)
+{
+    // Each thread stores a 4x4 tile
+    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
+
+    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
+
+    // Load the values across the 36 channels to compose the 6x6 tile
+    float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
+    float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
+    float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
+    float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
+    float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
+    float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
+
+    float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
+    float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
+    float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
+    float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
+    float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
+    float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
+
+    float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
+    float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
+    float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
+    float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
+    float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
+    float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
+
+    float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
+    float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
+    float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
+    float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
+    float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
+    float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
+
+    float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
+    float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
+    float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
+    float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
+    float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
+    float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
+
+    float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
+    float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
+    float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
+    float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
+    float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
+    float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
+
+    // Compute out00, out01, out02 and out03
+    float out00 = d01 + d21 + d41 + d11 + d31;
+    float out01 = d01 + d21 + d41 + d11 + d31;
+    float out02 = d01 + d21 + d41 + d11 + d31;
+    float out03 = d01 + d21 + d41 + d11 + d31;
+
+    float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
+    float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
+
+    out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
+    out01 += k1 - d02 - d12 - d22 - d32 - d42;
+    out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
+    out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
+
+    // Compute out10, out11, out12 and out13
+    float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+    float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+    float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+    float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
+
+    k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
+    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
+
+    out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
+    out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
+    out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
+    out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
+
+    // Compute out20, out21, out22 and out23
+    float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+    float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+    float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+    float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
+
+    k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
+    k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
+
+    out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
+    out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
+    out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
+    out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
+
+    // Compute out30, out31, out32 and out33
+    float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+    float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+    float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+    float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
+
+    k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
+    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
+
+    out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
+    out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
+    out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
+    out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
+
+    int y_in  = get_global_id(1);
+    int x_out = (y_in % NUM_TILES_X) * 4;
+    int y_out = (y_in / NUM_TILES_X) * 4;
+    int z_out = get_global_id(0);
+
+#if defined(HAS_BIAS)
+    // Add bias
+    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
+
+    float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
+
+    out00 += (float)b;
+    out01 += (float)b;
+    out02 += (float)b;
+    out03 += (float)b;
+
+    out10 += (float)b;
+    out11 += (float)b;
+    out12 += (float)b;
+    out13 += (float)b;
+
+    out20 += (float)b;
+    out21 += (float)b;
+    out22 += (float)b;
+    out23 += (float)b;
+
+    out30 += (float)b;
+    out31 += (float)b;
+    out32 += (float)b;
+    out33 += (float)b;
+
+#endif // defined(HAS_BIAS)
+
+    // Get output address
+    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
+
+    // Store the 4x4 output tile
+    vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
+    vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
+    vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
+    vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
+}
+
 #define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)  \
     ({                                                                   \
         comm_fact.s0 = d1 + d2;                                          \
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 0a0de7a..805a594 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -286,17 +286,23 @@
     else // The input tensors have not been reshaped
     {
         build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
+        build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
 
         // Create kernels according to the architecture, data type and input size.
         if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && is_data_type_float(data_type))
         {
-            kernel_name = "gemm_mm_floating_point_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
-            // The first kernel is optimized for the case of 1000 or less output elements (e.g. FC8 of AlexNet and VGG-16, and
-            // FC1 of Inception v3). The second kernel is optimized for the case of greater than 1000 output elements (e.g.
-            // FC6 and FC7 of AlexNet and VGG-16).
-            if(input1->info()->dimension(0) <= 1000 && input0->info()->num_dimensions() == 1 && data_type == DataType::F32)
+            kernel_name = "gemm_mm_floating_point";
+
+            if(input0->info()->num_dimensions() != 1)
             {
-                kernel_name += "_1000";
+                kernel_name += "_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
+            }
+            else if(input1->info()->dimension(0) <= 1000 && data_type == DataType::F32)
+            {
+                // The first kernel is optimized for the case of 1000 or less output elements (e.g. FC8 of AlexNet and VGG-16, and
+                // FC1 of Inception v3). The second kernel is optimized for the case of greater than 1000 output elements (e.g.
+                // FC6 and FC7 of AlexNet and VGG-16).
+                kernel_name += "_" + lower_string(string_from_data_type(data_type)) + "_bifrost_1000";
             }
 
             // The work-group size equal to the Bifrost quad size has been proved to be optimal for these kernels
@@ -309,7 +315,6 @@
         }
         else // (MIDGARD and F32) or (F16)
         {
-            build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
             kernel_name = "gemm_mm_floating_point";
         }
         build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
diff --git a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
index d3a33c0..41b3ac5 100644
--- a/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradFilterTransformKernel.cpp
@@ -55,9 +55,11 @@
     const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
     const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
 
-    ARM_COMPUTE_RETURN_ERROR_ON(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U));
-    ARM_COMPUTE_RETURN_ERROR_ON(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U) && output_tile_size != Size2D(4U, 4U));
-    ARM_COMPUTE_RETURN_ERROR_ON(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U));
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd filter transform only supports 3x3 and 5x5 kernels");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
+                                    && output_tile_size != Size2D(4U, 4U),
+                                    "Winograd filter transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd filter transform only supports 4x4 output tile for 5x5 kernels");
     ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(idx_w) != kernel_size.width || input->dimension(idx_h) != kernel_size.height);
     ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
 
diff --git a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
index a47590d..febd22b 100644
--- a/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradInputTransformKernel.cpp
@@ -47,7 +47,9 @@
     const Size2D        kernel_size      = winograd_info.kernel_size;
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Winograd input transform only supports 3x3 and 5x5 kernels");
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U), "Winograd input transform only supports 2x2 output tile for 3x3 kernels");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size != Size2D(2U, 2U)
+                                    && output_tile_size != Size2D(4U, 4U),
+                                    "Winograd input transform only supports 2x2 or 4x4 output tile for 3x3 kernels");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size != Size2D(4U, 4U), "Winograd input transform only supports 4x4 output tile for 5x5 kernels");
     ARM_COMPUTE_UNUSED(conv_info);
     ARM_COMPUTE_UNUSED(output_tile_size);
@@ -111,7 +113,6 @@
     const int num_elements_y = input->info()->dimension(1) - (kernel_size.height - 1) + conv_info.pad_top() + conv_info.pad_bottom();
 
     // Check if we need to extend the right or bottom border
-    // FIXME: This actually is not needed. Added just for validating the result;
     const unsigned int extra_border_right  = ((num_elements_x % output_tile_size.width) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.width - 1);
     const unsigned int extra_border_bottom = ((num_elements_y % output_tile_size.height) == 0) ? 0u : static_cast<unsigned int>(output_tile_size.height - 1);
 
@@ -137,19 +138,13 @@
     std::string kernel_name = "winograd_input_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string();
 
     // Check optimized kernel if output_dims == 2x2
-    if(output_tile_size.width == 2 && output_tile_size.height == 2)
+    if(output_tile_size == Size2D(2U, 2U))
     {
-        if((_input->info()->dimension(2) % 2) != 0)
-        {
-            _step_z = 1;
-        }
-        else
-        {
-            _step_z   = 2;
-            _lws_hint = cl::NDRange(1, 1, 8);
-        }
+        _step_z = (_input->info()->dimension(2) % 2) != 0 ? 1 : 2;
     }
 
+    _lws_hint = cl::NDRange(1, 1, 8);
+
     // Append stepz and data layout
     kernel_name += "_stepz";
     kernel_name += support::cpp11::to_string(_step_z);
diff --git a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
index 8ee1a82..c5d2528 100644
--- a/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
+++ b/src/core/CL/kernels/CLWinogradOutputTransformKernel.cpp
@@ -58,6 +58,7 @@
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size != Size2D(3U, 3U) && kernel_size != Size2D(5U, 5U), "Only 3x3 and 5x5 kernels are supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(2U, 2U) && input->dimension(2) != 16, "Wrong number of batches");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(3U, 3U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 36, "Wrong number of batches");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(kernel_size == Size2D(5U, 5U) && output_tile_size == Size2D(4U, 4U) && input->dimension(2) != 64, "Wrong number of batches");
 
     // Compute number of elements to process in the X and Y direction
@@ -67,7 +68,6 @@
     const int num_tiles_y    = std::ceil(num_elements_y / static_cast<float>(output_tile_size.height));
 
     ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != static_cast<unsigned int>((num_tiles_x * num_tiles_y)));
-    ARM_COMPUTE_UNUSED(output_tile_size);
 
     if(bias != nullptr)
     {
@@ -207,4 +207,4 @@
         enqueue(queue, *this, slice, _lws_hint);
     }
     while(window.slide_window_slice_3D(slice) && window.slide_window_slice_3D(slice_out));
-}
\ No newline at end of file
+}
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index bcb5424..643e24d 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -48,9 +48,16 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
     ARM_COMPUTE_ERROR_THROW_ON(CLConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info));
 
-    switch(CLConvolutionLayer::get_convolution_method(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info,
+    switch(CLConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info,
                                                       weights_info, act_info, CLScheduler::get().target(), dilation))
     {
+        case ConvolutionMethod::WINOGRAD:
+        {
+            auto f = arm_compute::support::cpp14::make_unique<CLWinogradConvolutionLayer>();
+            f->configure(input, weights, biases, output, conv_info);
+            _function = std::move(f);
+            break;
+        }
         case ConvolutionMethod::DIRECT:
         {
             auto f = arm_compute::support::cpp14::make_unique<CLDirectConvolutionLayer>();
@@ -79,8 +86,14 @@
     //Configure if the parameters match the direct convolution or the gemm-based
     const GPUTarget gpu_target = CLScheduler::get().target();
 
-    switch(CLConvolutionLayer::get_convolution_method(input, weights, biases, output, conv_info, weights_info, act_info, gpu_target, dilation))
+    switch(CLConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, act_info, gpu_target, dilation))
     {
+        case ConvolutionMethod::WINOGRAD:
+        {
+            //Validate Winograd
+            ARM_COMPUTE_RETURN_ON_ERROR(CLWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info));
+            break;
+        }
         case ConvolutionMethod::DIRECT:
         {
             // Validate direct convolution layer
@@ -101,19 +114,25 @@
     return Status{};
 }
 
-ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
+ConvolutionMethod CLConvolutionLayer::get_convolution_method(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info,
                                                              const WeightsInfo &weights_info, const ActivationLayerInfo &act_info, const GPUTarget gpu_target, const Size2D &dilation)
 {
-    ARM_COMPUTE_UNUSED(input);
-    ARM_COMPUTE_UNUSED(weights);
-    ARM_COMPUTE_UNUSED(biases);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(output);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
     ARM_COMPUTE_UNUSED(output);
-    ARM_COMPUTE_UNUSED(conv_info);
     ARM_COMPUTE_UNUSED(weights_info);
     ARM_COMPUTE_UNUSED(gpu_target);
-    ARM_COMPUTE_UNUSED(dilation);
-    ARM_COMPUTE_UNUSED(act_info);
 
+    const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+    const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+    const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+
+    if((input->data_type() == DataType::F32) && (input->data_layout() == DataLayout::NCHW) && (input->dimension(idx_c) > 3) && (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3)
+       && (weights->num_dimensions() <= 4) && (conv_info.stride().first == 1) && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)) && (!act_info.enabled()))
+    {
+        return ConvolutionMethod::WINOGRAD;
+    }
     return ConvolutionMethod::GEMM;
 }
 
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index 0aa7f8d..86ccdda 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -44,13 +44,18 @@
     const size_t idx_height = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT);
 
     // Input shape
-    const TensorShape input_shape = input->info()->tensor_shape();
+    const TensorShape  input_shape = input->info()->tensor_shape();
+    const unsigned int input_w     = input->info()->tensor_shape()[idx_width];
+    const unsigned int input_h     = input->info()->tensor_shape()[idx_height];
 
     // Kernel size
     const unsigned int kernel_w = weights->info()->tensor_shape()[idx_width];
     const unsigned int kernel_h = weights->info()->tensor_shape()[idx_height];
 
-    const WinogradInfo winograd_info = WinogradInfo(Size2D(2, 2),
+    //Winograd output tile
+    const Size2D output_tile = (Size2D(kernel_w, kernel_h) == Size2D(3U, 3U) && input_w <= 4 && input_h <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+
+    const WinogradInfo winograd_info = WinogradInfo(output_tile,
                                                     Size2D(kernel_w, kernel_h),
                                                     Size2D(input_shape[idx_width], input_shape[idx_height]),
                                                     conv_info,
@@ -95,13 +100,18 @@
     const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
 
     // Input shape
-    const TensorShape input_shape = input->tensor_shape();
+    const TensorShape  input_shape = input->tensor_shape();
+    const unsigned int input_w     = input->tensor_shape()[idx_width];
+    const unsigned int input_h     = input->tensor_shape()[idx_height];
 
     // Kernel size
     const unsigned int kernel_w = weights->tensor_shape()[idx_width];
     const unsigned int kernel_h = weights->tensor_shape()[idx_height];
 
-    const WinogradInfo winograd_info = WinogradInfo(Size2D(2, 2),
+    //Winograd output tile
+    const Size2D output_tile = (Size2D(kernel_w, kernel_h) == Size2D(3U, 3U) && input_w <= 4 && input_h <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+
+    const WinogradInfo winograd_info = WinogradInfo(output_tile,
                                                     Size2D(kernel_w, kernel_h),
                                                     Size2D(input_shape[idx_width], input_shape[idx_height]),
                                                     conv_info,
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index afc3545..b0603e9 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -109,10 +109,12 @@
     ARM_COMPUTE_ERROR_ON_NULLPTR(weights);
     ARM_COMPUTE_UNUSED(output);
     ARM_COMPUTE_UNUSED(weights_info);
-    ARM_COMPUTE_UNUSED(act_info);
 
-    if((input->data_type() == DataType::F32) && (weights->dimension(0) == 3) && (weights->dimension(1) == 3) && (weights->num_dimensions() <= 4) && (conv_info.stride().first == 1)
-       && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)))
+    const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+    const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+
+    if((input->data_type() == DataType::F32) && (input->data_layout() == DataLayout::NCHW) && (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3) && (weights->num_dimensions() <= 4)
+       && (conv_info.stride().first == 1) && (conv_info.stride().second == 1) && (dilation == Size2D(1U, 1U)) && (!act_info.enabled()))
     {
         //FIXME Until COMPMID-1041 is implemented Winograd is slower than GEMM on A53.
         if(Scheduler::get().cpu_info().get_cpu_model() != CPUModel::A53)