COMPMID-477 - Optimize CLDirectConvolution 3x3 on OpenCL and add output auto configuration

Change-Id: I3c8384dcbc9d7786943134bb658dafb35356d90d
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/83253
Reviewed-by: Steven Niu <steven.niu@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
diff --git a/src/core/CL/cl_kernels/direct_convolution1x1.cl b/src/core/CL/cl_kernels/direct_convolution1x1.cl
index d161f80..ec0551b 100644
--- a/src/core/CL/cl_kernels/direct_convolution1x1.cl
+++ b/src/core/CL/cl_kernels/direct_convolution1x1.cl
@@ -113,10 +113,11 @@
  *
  * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
  * @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
- * @note The convolution stride x and stride y must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1, _DSTRIDE_Y=1
+ * @note The convolution stride x must be passed at compile time using -DSTRIDE_X e.g. -DSTRIDE_X=1
+ * @note The third dimension (depth) of the weights tensor must be passed at compile time using -DWEIGHTS_DEPTH
  * @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
  *
- * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: QS8/F16/F32
+ * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: F16/F32
  * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
  * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
@@ -144,9 +145,9 @@
  * @param[in]  biases_stride_x                       Stride of the biases tensor in X dimension (in bytes)
  * @param[in]  biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  biases_offset_first_element_in_bytes  The offset of the first element in the biases tensor
- * @param[in]  weights_stride_w                      Stride of the weights tensor in W dimension
- * @param[in]  filter_depth                          The depth size of the filter
+ * @param[in]  weights_stride_w                      Stride of the weights tensor in the 4th dimension (in bytes)
  */
+#if defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
 __kernel void direct_convolution1x1(
     TENSOR3D_DECLARATION(src),
     TENSOR3D_DECLARATION(dst),
@@ -154,8 +155,7 @@
 #ifdef HAS_BIAS
     VECTOR_DECLARATION(biases),
 #endif /* defined(HAS_BIAS) */
-    unsigned int weights_stride_w,
-    unsigned int filter_depth)
+    unsigned int weights_stride_w)
 {
     Image    src     = CONVERT_TO_IMAGE_STRUCT(src);
     Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
@@ -172,7 +172,7 @@
 
     weights.ptr += z_index * weights_stride_w;
 
-    for(int d = 0; d < filter_depth; ++d)
+    for(int d = 0; d < WEIGHTS_DEPTH; ++d)
     {
         DATA_TYPE weight = *(__global DATA_TYPE *)weights.ptr;
         VEC_DATA_TYPE(DATA_TYPE, 8)
@@ -188,3 +188,4 @@
 
     vstore8(pixels, 0, (__global DATA_TYPE *)dst.ptr);
 }
+#endif // defined(DATA_TYPE) && defined(DATA_SIZE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/direct_convolution3x3.cl b/src/core/CL/cl_kernels/direct_convolution3x3.cl
index b5524e1..51886ef 100644
--- a/src/core/CL/cl_kernels/direct_convolution3x3.cl
+++ b/src/core/CL/cl_kernels/direct_convolution3x3.cl
@@ -23,124 +23,48 @@
  */
 #include "helpers.h"
 
-#if STRIDE_X == 2
-#define CONVOLVE1x3(left_pixel_position, left_coeff, middle_coeff, right_coeff) convolution1x3_stride2(left_pixel_position, left_coeff, middle_coeff, right_coeff)
-#elif STRIDE_X == 1 /* STRIDE_X == 1 */
-#define CONVOLVE1x3(left_pixel_position, left_coeff, middle_coeff, right_coeff) convolution1x3_stride1(left_pixel_position, left_coeff, middle_coeff, right_coeff)
+#if STRIDE_X == 1
+#define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)
+#elif STRIDE_X == 2 /* STRIDE_X == 2 */
+#define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)
 #else /* STRIDE_X not equals 1 or 2 */
 #error "STRIDE_X larger than 2 is not supported"
 #endif /* STRIDE_X == 2 */
 
-/** Compute a 1D horizontal convolution of size 3 with stride as 1.
- *
- * @param[in] left_pixel   Pointer to the left pixel.
- * @param[in] left_coeff   Weight of the left pixel
- * @param[in] middle_coeff Weight of the middle pixel
- * @param[in] right_coeff  Weight of the right pixel
- *
- * @return a convoluted values.
- */
-inline VEC_DATA_TYPE(DATA_TYPE, 8) convolution1x3_stride1(__global const DATA_TYPE *left_pixel,
-                                                          const DATA_TYPE left_coeff,
-                                                          const DATA_TYPE middle_coeff,
-                                                          const DATA_TYPE right_coeff)
-{
-    VEC_DATA_TYPE(DATA_TYPE, 16)
-    temp = vload16(0, left_pixel);
+#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)                                                               \
+    ({                                                                                                                          \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                             \
+        weights_values0 = vload4(0, weights_row_ptr);                                                                           \
+        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                             \
+        src0 = vload8(0, src_row_ptr);                                                                                          \
+        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                             \
+        src1 = vload2(0, src_row_ptr + 8);                                                                                      \
+        \
+        acc += src0 * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                          \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1; \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+    })
 
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    left = temp.s01234567;
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    middle = temp.s12345678;
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    right = temp.s23456789;
-
-    return left * (VEC_DATA_TYPE(DATA_TYPE, 8))left_coeff + middle * (VEC_DATA_TYPE(DATA_TYPE, 8))middle_coeff + right * (VEC_DATA_TYPE(DATA_TYPE, 8))right_coeff;
-}
-
-/** Compute a 1D horizontal convolution of size 3 with stride as 2.
- *
- * @param[in] left_pixel   Pointer to the left pixel.
- * @param[in] left_coeff   Weight of the left pixel
- * @param[in] middle_coeff Weight of the middle pixel
- * @param[in] right_coeff  Weight of the right pixel
- *
- * @return a convoluted values.
- */
-inline VEC_DATA_TYPE(DATA_TYPE, 8) convolution1x3_stride2(__global const DATA_TYPE *left_pixel,
-                                                          const DATA_TYPE left_coeff,
-                                                          const DATA_TYPE middle_coeff,
-                                                          const DATA_TYPE right_coeff)
-{
-    const int stride_size = 2;
-
-    VEC_DATA_TYPE(DATA_TYPE, 16)
-    temp1 = vload16(0, left_pixel);
-
-    VEC_DATA_TYPE(DATA_TYPE, 16)
-    temp2 = vload16(0, left_pixel + 8);
-
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    left = (VEC_DATA_TYPE(DATA_TYPE, 8))(temp1.s0246, temp2.s0246);
-
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    middle = (VEC_DATA_TYPE(DATA_TYPE, 8))(temp1.s1357, temp2.s1357);
-
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    right = (VEC_DATA_TYPE(DATA_TYPE, 8))(temp1.s2468, temp2.s2468);
-
-    return left * (VEC_DATA_TYPE(DATA_TYPE, 8))left_coeff + middle * (VEC_DATA_TYPE(DATA_TYPE, 8))middle_coeff + right * (VEC_DATA_TYPE(DATA_TYPE, 8))right_coeff;
-}
-
-/** Apply a 3x3 2D convolution matrix on the input and return the result.
- *
- * Convolution matrix layout:
- *
- * [ mat0, mat1, mat2 ]\n
- * [ mat3, mat4, mat5 ]\n
- * [ mat6, mat7, mat8 ]\n
- *
- * @param[in] src  A pointer to source Image structure
- * @param[in] mat0 Coefficient from the convolution matrix
- * @param[in] mat1 Coefficient from the convolution matrix
- * @param[in] mat2 Coefficient from the convolution matrix
- * @param[in] mat3 Coefficient from the convolution matrix
- * @param[in] mat4 Coefficient from the convolution matrix
- * @param[in] mat5 Coefficient from the convolution matrix
- * @param[in] mat6 Coefficient from the convolution matrix
- * @param[in] mat0 Coefficient from the convolution matrix
- * @param[in] mat7 Coefficient from the convolution matrix
- * @param[in] mat8 Coefficient from the convolution matrix
- *
- * @return convoluted values.
- */
-inline VEC_DATA_TYPE(DATA_TYPE, 8) convolution3x3(
-    Image          *src,
-    const DATA_TYPE mat0, const DATA_TYPE mat1, const DATA_TYPE mat2,
-    const DATA_TYPE mat3, const DATA_TYPE mat4, const DATA_TYPE mat5,
-    const DATA_TYPE mat6, const DATA_TYPE mat7, const DATA_TYPE mat8)
-{
-    // Output pixels
-    VEC_DATA_TYPE(DATA_TYPE, 8)
-    pixels;
-
-    // Row 0
-    pixels = CONVOLVE1x3((__global DATA_TYPE *)offset(src, 0, 0), mat0, mat1, mat2);
-    // Row
-    pixels += CONVOLVE1x3((__global DATA_TYPE *)offset(src, 0, 1), mat3, mat4, mat5);
-    // Row 2
-    pixels += CONVOLVE1x3((__global DATA_TYPE *)offset(src, 0, 2), mat6, mat7, mat8);
-
-    return pixels;
-}
+#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)                                                            \
+    ({                                                                                                                       \
+        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                          \
+        weights_values0 = vload4(0, weights_row_ptr);                                                                        \
+        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                         \
+        src0           = vload16(0, src_row_ptr);                                                                            \
+        DATA_TYPE src1 = *(src_row_ptr + 16);                                                                                \
+        \
+        acc += src0.even * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0;                                                  \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1;      \
+        acc += (VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1) * (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2; \
+    })
 
 /** This kernel performs a direct convolution to convolve the low three dimensions.
  *
  * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
- * @note The convolution stride x and stride y must be passed at compile time using -DSTRIDE_X and -DSTRIDE_Y: e.g. -DSTRIDE_X=1, _DSTRIDE_Y=1
+ * @note The third dimension (depth) of the weights tensor must be passed at compile time using -DWEIGHTS_DEPTH
  * @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
  *
- * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: QS8/F16/F32
+ * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: F16/F32
  * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
  * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
@@ -168,9 +92,9 @@
  * @param[in]  biases_stride_x                       Stride of the biases tensor in X dimension (in bytes)
  * @param[in]  biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  biases_offset_first_element_in_bytes  The offset of the first element in the biases tensor
- * @param[in]  weights_stride_w                      Stride of the weights tensor in W dimension
- * @param[in]  filter_depth                          The depth size of the filter
+ * @param[in]  weights_stride_w                      Stride of the weights tensor in the 4th dimension (in bytes)
  */
+#if defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
 __kernel void direct_convolution3x3(
     TENSOR3D_DECLARATION(src),
     TENSOR3D_DECLARATION(dst),
@@ -178,50 +102,37 @@
 #ifdef HAS_BIAS
     VECTOR_DECLARATION(biases),
 #endif /* defined(HAS_BIAS) */
-    unsigned int weights_stride_w,
-    unsigned int filter_depth)
+    unsigned int weights_stride_w)
 {
     Image    src     = CONVERT_TO_IMAGE_STRUCT(src);
     Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
     Tensor3D dst     = CONVERT_TO_TENSOR3D_STRUCT(dst);
 
-#ifdef HAS_BIAS
-    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
-#endif /* defined(HAS_BIAS) */
-
     VEC_DATA_TYPE(DATA_TYPE, 8)
-    pixels = 0;
+    pixels0 = 0;
 
-    const uint z_index = get_global_id(2);
+    __global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
+    __global uchar *src_addr     = (__global uchar *)offset(&src, 0, 0);
 
-    weights.ptr += z_index * weights_stride_w;
+    const int kernel_index = get_global_id(2);
+    weights_addr += kernel_index * weights_stride_w;
 
-    for(int d = 0; d < filter_depth; ++d)
+    for(int d = 0; d < WEIGHTS_DEPTH; ++d)
     {
-        VEC_DATA_TYPE(DATA_TYPE, 4)
-        weights_row1 = vload4(0, (__global DATA_TYPE *)tensor3D_offset(&weights, 0, 0, 0));
-        VEC_DATA_TYPE(DATA_TYPE, 4)
-        weights_row2 = vload4(0, (__global DATA_TYPE *)tensor3D_offset(&weights, 0, 1, 0));
-        VEC_DATA_TYPE(DATA_TYPE, 4)
-        weights_row3 = vload4(0, (__global DATA_TYPE *)tensor3D_offset(&weights, 0, 2, 0));
+        CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 0 * weights_stride_y));
+        CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 1 * weights_stride_y));
+        CONVOLUTION1x3(pixels0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y), (__global DATA_TYPE *)(weights_addr + 2 * weights_stride_y));
 
-        pixels += convolution3x3(&src, weights_row1.s0,
-                                 weights_row1.s1,
-                                 weights_row1.s2,
-                                 weights_row2.s0,
-                                 weights_row2.s1,
-                                 weights_row2.s2,
-                                 weights_row3.s0,
-                                 weights_row3.s1,
-                                 weights_row3.s2);
-
-        src.ptr += src_stride_z;
-        weights.ptr += weights_stride_z;
+        src_addr += src_stride_z;
+        weights_addr += weights_stride_z;
     }
 
 #ifdef HAS_BIAS
-    pixels += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, z_index)));
+    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
+
+    pixels0 += (VEC_DATA_TYPE(DATA_TYPE, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, kernel_index)));
 #endif /* defined(HAS_BIAS) */
 
-    vstore8(pixels, 0, (__global DATA_TYPE *)dst.ptr);
+    vstore8(pixels0, 0, (__global DATA_TYPE *)dst.ptr);
 }
+#endif // defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
\ No newline at end of file
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 1f481de..5f14d16 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -32,6 +32,7 @@
 #include "arm_compute/core/IAccessWindow.h"
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "support/ToolchainSupport.h"
 
@@ -49,20 +50,17 @@
 
 void CLDirectConvolutionLayerKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
 {
-    const unsigned int kernel_size = weights->info()->dimension(0);
-    ARM_COMPUTE_ERROR_ON_MSG(kernel_size != 1 && kernel_size != 3,
-                             "Kernel sizes other than 1x1 or 3x3 are not supported");
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) != weights->info()->dimension(1),
+                             "Only kernel sizes 1x1 and 3x3 are supported");
+    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) != 1 && weights->info()->dimension(0) != 3,
+                             "Only kernel sizes 1x1 and 3x3 are supported");
     ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
     ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
     ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
-    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
-                             "Pad > 0 not supported for 1x1 weights");
-    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
-                             "Pad > 1 not supported for 3x3 weights");
-    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
-    ARM_COMPUTE_ERROR_ON_MSG((kernel_size == 3 && std::get<0>(conv_info.stride()) > 2), "Strides larger than 2 not supported in 3x3 direct convolution!");
+    ARM_COMPUTE_ERROR_ON_MSG((weights->info()->dimension(0) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution.");
+    ARM_COMPUTE_ERROR_ON_MSG((weights->info()->dimension(0) == 3) && std::get<0>(conv_info.stride()) > 2, "Strides larger than 2 not supported for 3x3 convolution.");
 
     if(biases != nullptr)
     {
@@ -71,10 +69,29 @@
         ARM_COMPUTE_ERROR_ON(biases->info()->num_dimensions() > 1);
     }
 
+    const unsigned int kernel_size = weights->info()->dimension(0);
+
+    // Get convolved dimensions
+    unsigned int output_width  = 0;
+    unsigned int output_height = 0;
+    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);
+
+    TensorShape output_shape = input->info()->tensor_shape();
+    output_shape.set(0, output_width);
+    output_shape.set(1, output_height);
+    output_shape.set(2, weights->info()->dimension(3));
+
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+
     _conv_stride_x = std::get<0>(conv_info.stride());
     _conv_stride_y = std::get<1>(conv_info.stride());
-    _conv_pad_x    = std::get<0>(conv_info.pad());
-    _conv_pad_y    = std::get<1>(conv_info.pad());
+    _conv_pad_x    = std::min(std::get<0>(conv_info.pad()), kernel_size / 2);
+    _conv_pad_y    = std::min(std::get<1>(conv_info.pad()), kernel_size / 2);
 
     _input       = input;
     _weights     = weights;
@@ -86,9 +103,9 @@
     std::set<std::string> options;
     kernel_name << "direct_convolution" << kernel_size << "x" << kernel_size;
 
-    options.insert("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
-    options.insert("-DDATA_SIZE=" + get_data_size_from_data_type(input->info()->data_type()));
-
+    options.emplace("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+    options.emplace("-DDATA_SIZE=" + get_data_size_from_data_type(input->info()->data_type()));
+    options.emplace("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(2)));
     options.emplace("-DSTRIDE_X=" + support::cpp11::to_string(_conv_stride_x));
 
     if(_biases != nullptr)
@@ -98,33 +115,27 @@
 
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name.str(), options));
 
-    unsigned int idx = (_biases == nullptr) ? 3 * num_arguments_per_3D_tensor() : (num_arguments_per_1D_tensor() + 3 * num_arguments_per_3D_tensor());
-    _kernel.setArg<cl_uint>(idx++, _weights->info()->strides_in_bytes()[3]); // weights_stride_w
-    _kernel.setArg<cl_uint>(idx++, _weights->info()->dimension(2));          // filter depth
-
-    // Using this local workgroup size gives better performance over others that have been tried.
-    _lws_hint = cl::NDRange(4, 1, 8);
-
     // Configure kernel window
     Window win = calculate_max_window(*output->info());
 
-    unsigned int num_elems_read_per_iteration    = 16 * _conv_stride_x;
-    unsigned int num_elems_written_per_iteration = 8;
+    bool is_kernel3x3_stride2 = ((kernel_size == 3) && (_conv_stride_x == 2));
+
+    const unsigned int num_elems_read_per_iteration_x    = 8 + 2 * (kernel_size / 2) + (is_kernel3x3_stride2 ? 7 : 0);
+    const unsigned int num_elems_read_per_iteration_y    = kernel_size;
+    const unsigned int num_elems_written_per_iteration_x = 8;
+    const unsigned int num_elems_written_per_iteration_y = 1;
 
     // Calculate right and bottom border
-    const int input_width    = input->info()->dimension(0);
-    const int input_height   = input->info()->dimension(1);
-    const int upper_bound_w  = ceil_to_multiple(((output->info()->dimension(0) - 1) * _conv_stride_x + kernel_size), num_elems_read_per_iteration) - _conv_pad_x - input_width;
-    const int upper_bound_h  = ((output->info()->dimension(1) - 1) * _conv_stride_y - _conv_pad_y + kernel_size) - input_height;
-    const int padding_right  = std::max(upper_bound_w, static_cast<int>(kernel_size));
-    const int padding_bottom = std::max(upper_bound_h, static_cast<int>(kernel_size));
+    const int input_width  = input->info()->dimension(0) - kernel_size / 2 + _conv_pad_x;
+    const int input_height = input->info()->dimension(1) - kernel_size / 2 + _conv_pad_y;
 
     // Create window and update padding
-    win = calculate_max_window(*output->info(), Steps(num_elems_written_per_iteration));
-    AccessWindowStatic input_access(input->info(), -_conv_pad_x, -_conv_pad_y, input_width + padding_right, input_height + padding_bottom);
+    win = calculate_max_window(*output->info(), Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y));
 
-    AccessWindowStatic     weights_access(weights->info(), 0, 0, kernel_size, kernel_size);
-    AccessWindowHorizontal output_access(output->info(), 0, num_elems_written_per_iteration);
+    AccessWindowStatic    input_access(input->info(), -_conv_pad_x, -_conv_pad_y, input_width + num_elems_read_per_iteration_x, input_height + num_elems_read_per_iteration_y);
+    AccessWindowStatic    weights_access(weights->info(), 0, 0, kernel_size, kernel_size);
+    AccessWindowRectangle output_access(output->info(), 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
+
     update_window_and_padding(win, input_access, weights_access, output_access);
 
     output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
@@ -158,6 +169,8 @@
         add_1D_tensor_argument(idx1, _biases, slice_biases);
     }
 
+    _kernel.setArg(idx1++, static_cast<unsigned int>(_weights->info()->strides_in_bytes()[3]));
+
     do
     {
         unsigned int idx = 0;
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 43292d1..3a102ed 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -30,6 +30,7 @@
 #include "arm_compute/core/ITensor.h"
 #include "arm_compute/core/NEON/NEFixedPoint.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 
 #include <algorithm>
@@ -952,13 +953,15 @@
 void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS16, DataType::F16, DataType::QS32, DataType::F32);
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
     ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                              "Pad > 0 not supported for 1x1 weights");
     ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                              "Pad > 1 not supported for 3x3 weights");
     ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
+    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
+    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
+    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
 
     const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
     const unsigned int conv_pad_x    = std::get<0>(conv_info.pad());
@@ -971,6 +974,32 @@
     _kernel_size = weights->info()->dimension(0);
     _border_size = BorderSize(conv_pad_y, conv_pad_x);
 
+    const unsigned int kernel_size = weights->info()->dimension(0);
+
+    // Get convolved dimensions
+    unsigned int output_width  = 0;
+    unsigned int output_height = 0;
+    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);
+
+    TensorShape output_shape = input->info()->tensor_shape();
+    output_shape.set(0, output_width);
+    output_shape.set(1, output_height);
+    output_shape.set(2, weights->info()->dimension(3));
+
+    DataType data_type = input->info()->data_type();
+
+    if(is_data_type_fixed_point(data_type))
+    {
+        // Promote data type in case of fixed point
+        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
+    }
+
+    // Output auto initialization if not yet initialized
+    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());
+
+    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
+    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, output->info()->data_type());
+
     Window win = calculate_max_window(*output->info());
 
     switch(_kernel_size)
diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
index 0380e8c..2e3a683 100644
--- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
@@ -40,8 +40,6 @@
 
 void NEDirectConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *bias, ITensor *output, const PadStrideInfo &conv_info)
 {
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
-
     // Free accumulator
     if(_accumulator.buffer() != nullptr)
     {