COMPMID-3144: Remove padding from NEDirectConvolutionLayerKernel

Change-Id: I22b907eebfbe037e6e1c7bf604172f4709a9cbed
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4082
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index ac1d6ae..c22fa6a 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -40,9 +40,10 @@
 
 #include <algorithm>
 
-using namespace arm_compute;
 using namespace arm_compute::detail;
 
+namespace arm_compute
+{
 namespace
 {
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -472,117 +473,6 @@
     return out;
 }
 
-template <typename T1>
-class convolver_nhwc
-{
-public:
-    static void convolve(const Window &window, uint32_t kernel_size, unsigned int num_elems_read_per_iteration,
-                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
-    {
-        const int          input_width     = input->info()->dimension(0);
-        const int          input_depth     = input->info()->dimension(2);
-        const int          input_stride_x  = input->info()->strides_in_bytes().x();
-        const int          input_stride_y  = input->info()->strides_in_bytes().y();
-        const int          input_stride_z  = input->info()->strides_in_bytes().z();
-        const int          output_stride_x = output->info()->strides_in_bytes().x();
-        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
-        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
-        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
-        const int          conv_pad_top    = conv_info.pad_top();
-        const unsigned int conv_stride_x   = std::get<0>(conv_info.stride());
-        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
-        const T1           zero            = 0;
-
-        // Setup input window for the input iterator
-        Window window_in = window;
-        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
-        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
-        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
-
-        // Setup input window for the output iterator
-        Window window_out = window;
-        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-        // Setup input window for the weights iterator
-        Window window_k = calculate_max_window(*weights->info(), Steps());
-        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
-        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
-        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
-        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
-
-        Iterator in(input, window_in);
-        Iterator out(output, window_out);
-        Iterator k(weights, window_k);
-
-        execute_window_loop(window_k, [&](const Coordinates & id_k)
-        {
-            execute_window_loop(window_out, [&](const Coordinates & id)
-            {
-                const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
-                const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
-
-                const uint8_t *in_ptr  = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
-                uint8_t       *out_ptr = out.ptr() + id_k[3] * output_stride_x;
-
-                T1 out_val = 0;
-
-                auto in_addr_base0 = in_ptr;
-                auto we_addr_base0 = k.ptr();
-
-                for(uint32_t z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
-                {
-                    const int in_z = id.z() * conv_stride_y + z - conv_pad_top;
-
-                    if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
-                    {
-                        auto in_addr_base1 = in_addr_base0;
-                        auto we_addr_base1 = we_addr_base0;
-
-                        for(uint32_t y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
-                        {
-                            auto out_values = internal_vdupq_n(zero);
-
-                            int x           = 0;
-                            int no_leftover = input_width - num_elems_read_per_iteration;
-
-                            for(; x < no_leftover; x += num_elems_read_per_iteration)
-                            {
-                                const auto in_addr   = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
-                                const auto in_values = internal_vld1q<1>(in_addr);
-
-                                const auto we_addr   = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
-                                const auto we_values = internal_vld1q<1>(we_addr);
-
-                                out_values = internal_vmlal(out_values, in_values, we_values);
-                            }
-
-                            auto carry_addition = wrapper::vpadd(wrapper::vgethigh(out_values), wrapper::vgetlow(out_values));
-                            carry_addition      = wrapper::vpadd(carry_addition, carry_addition);
-                            out_val += wrapper::vgetlane(carry_addition, 0);
-
-                            // Leftover
-                            for(; x < input_width; ++x)
-                            {
-                                const auto in_addr  = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
-                                const auto in_value = *(in_addr);
-
-                                const auto we_addr  = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
-                                const auto we_value = *(we_addr);
-
-                                out_val += in_value * we_value;
-                            }
-                        }
-                    }
-                }
-
-                *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
-            },
-            in, out);
-        },
-        k);
-    }
-};
-
 template <typename T1, typename T2, unsigned int stridex>
 class convolver_3x3
 {
@@ -815,76 +705,6 @@
     }
 };
 
-inline void convolve_row1x9_nhwc(const float *row_ptr, const float *weights_ptr, size_t src_stride_y, size_t weights_stride_y,
-                                 float32x4_t &acc0, float32x4_t &acc1, float32x4_t &acc2, float32x4_t &acc3)
-{
-    // Load 4 channels for each of the 12 inputs values along the same X spatial dimension
-    const float32x4_t src0  = wrapper::vloadq(row_ptr);
-    const float32x4_t src1  = wrapper::vloadq(row_ptr + 1 * src_stride_y);
-    const float32x4_t src2  = wrapper::vloadq(row_ptr + 2 * src_stride_y);
-    const float32x4_t src3  = wrapper::vloadq(row_ptr + 3 * src_stride_y);
-    const float32x4_t src4  = wrapper::vloadq(row_ptr + 4 * src_stride_y);
-    const float32x4_t src5  = wrapper::vloadq(row_ptr + 5 * src_stride_y);
-    const float32x4_t src6  = wrapper::vloadq(row_ptr + 6 * src_stride_y);
-    const float32x4_t src7  = wrapper::vloadq(row_ptr + 7 * src_stride_y);
-    const float32x4_t src8  = wrapper::vloadq(row_ptr + 8 * src_stride_y);
-    const float32x4_t src9  = wrapper::vloadq(row_ptr + 9 * src_stride_y);
-    const float32x4_t src10 = wrapper::vloadq(row_ptr + 10 * src_stride_y);
-    const float32x4_t src11 = wrapper::vloadq(row_ptr + 11 * src_stride_y);
-
-    // Load 4 channels for each of the 9 weights values along the same X spatial dimension
-    const float32x4_t w0 = wrapper::vloadq(weights_ptr);
-    const float32x4_t w1 = wrapper::vloadq(weights_ptr + 1 * weights_stride_y);
-    const float32x4_t w2 = wrapper::vloadq(weights_ptr + 2 * weights_stride_y);
-    const float32x4_t w3 = wrapper::vloadq(weights_ptr + 3 * weights_stride_y);
-    const float32x4_t w4 = wrapper::vloadq(weights_ptr + 4 * weights_stride_y);
-    const float32x4_t w5 = wrapper::vloadq(weights_ptr + 5 * weights_stride_y);
-    const float32x4_t w6 = wrapper::vloadq(weights_ptr + 6 * weights_stride_y);
-    const float32x4_t w7 = wrapper::vloadq(weights_ptr + 7 * weights_stride_y);
-    const float32x4_t w8 = wrapper::vloadq(weights_ptr + 8 * weights_stride_y);
-
-    // Store 4 channels for each of the 4 output values along the same X spatial dimension
-    acc0 = wrapper::vmla(acc0, w0, src0);
-    acc0 = wrapper::vmla(acc0, w1, src1);
-    acc0 = wrapper::vmla(acc0, w2, src2);
-    acc0 = wrapper::vmla(acc0, w3, src3);
-    acc0 = wrapper::vmla(acc0, w4, src4);
-    acc0 = wrapper::vmla(acc0, w5, src5);
-    acc0 = wrapper::vmla(acc0, w6, src6);
-    acc0 = wrapper::vmla(acc0, w7, src7);
-    acc0 = wrapper::vmla(acc0, w8, src8);
-
-    acc1 = wrapper::vmla(acc1, w0, src1);
-    acc1 = wrapper::vmla(acc1, w1, src2);
-    acc1 = wrapper::vmla(acc1, w2, src3);
-    acc1 = wrapper::vmla(acc1, w3, src4);
-    acc1 = wrapper::vmla(acc1, w4, src5);
-    acc1 = wrapper::vmla(acc1, w5, src6);
-    acc1 = wrapper::vmla(acc1, w6, src7);
-    acc1 = wrapper::vmla(acc1, w7, src8);
-    acc1 = wrapper::vmla(acc1, w8, src9);
-
-    acc2 = wrapper::vmla(acc2, w0, src2);
-    acc2 = wrapper::vmla(acc2, w1, src3);
-    acc2 = wrapper::vmla(acc2, w2, src4);
-    acc2 = wrapper::vmla(acc2, w3, src5);
-    acc2 = wrapper::vmla(acc2, w4, src6);
-    acc2 = wrapper::vmla(acc2, w5, src7);
-    acc2 = wrapper::vmla(acc2, w6, src8);
-    acc2 = wrapper::vmla(acc2, w7, src9);
-    acc2 = wrapper::vmla(acc2, w8, src10);
-
-    acc3 = wrapper::vmla(acc3, w0, src3);
-    acc3 = wrapper::vmla(acc3, w1, src4);
-    acc3 = wrapper::vmla(acc3, w2, src5);
-    acc3 = wrapper::vmla(acc3, w3, src6);
-    acc3 = wrapper::vmla(acc3, w4, src7);
-    acc3 = wrapper::vmla(acc3, w5, src8);
-    acc3 = wrapper::vmla(acc3, w6, src9);
-    acc3 = wrapper::vmla(acc3, w7, src10);
-    acc3 = wrapper::vmla(acc3, w8, src11);
-}
-
 float vreduce(const float32x4_t &v)
 {
     auto v0    = wrapper::vgethigh(v);
@@ -896,175 +716,6 @@
     return a + b;
 }
 
-template <typename V>
-class convolver_9x9_nhwc
-{
-public:
-    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration,
-                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
-    {
-        // Declare useful types
-        using vector_type = typename V::type;
-        using scalar_type = typename V::scalar_type;
-        using tag_type    = typename V::tag_type;
-
-        // Scalar quantities
-        const int          element_size    = input->info()->element_size();
-        const int          input_width     = input->info()->dimension(0);
-        const int          input_depth     = input->info()->dimension(2);
-        const int          input_stride_y  = input->info()->strides_in_bytes().y() / element_size;
-        const int          input_stride_z  = input->info()->strides_in_bytes().z() / element_size;
-        const int          input_stride_w  = input->info()->strides_in_bytes()[3];
-        const int          output_stride_x = output->info()->strides_in_bytes().x();
-        const int          output_stride_y = output->info()->strides_in_bytes().y();
-        const int          kernel_stride_y = weights->info()->strides_in_bytes().y() / element_size;
-        const int          kernel_stride_z = weights->info()->strides_in_bytes().z() / element_size;
-        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
-        const unsigned int conv_pad_top    = conv_info.pad_top();
-        const unsigned int conv_pad_left   = conv_info.pad_left();
-
-        // Setup input window for the input iterator
-        Window window_in = window;
-        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
-        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
-        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
-
-        // Setup input window for the output iterator
-        Window window_out = window;
-        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
-
-        // Setup input window for the weights iterator
-        Window window_k = calculate_max_window(*weights->info(), Steps());
-        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
-        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
-        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
-        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
-
-        Iterator in(input, window_in);
-        Iterator out(output, window_out);
-        Iterator k(weights, window_k);
-
-        // Calculate the max_offset.
-        // max_offset is the offset for the last NOT valid value in the Z dimension (spatial dimension Y for NHWC)
-        //  |******************|
-        //  |     pad_top      |
-        //  |******************|
-        //  |                  |
-        //  |      plane0      |
-        //  |      batch0      |
-        //  |__________________|
-        //  |******************|       Batch 0
-        //  |    pad_bottom    |
-        //  |     pad_top      |
-        //  |******************|
-        //  |                  |
-        //  |      plane1      |
-        //  |      batch0      |
-        //  |__________________|-----> max_offset
-        //  |******************|
-        //  |    pad_bottom    |
-        //  |     pad_top      |
-        //  |******************|
-        //  |                  |
-        //  |      plane0      |
-        //  |      batch1      |
-        //  |__________________|
-        //  |******************|       Batch 1
-        //  |    pad_bottom    |
-        //  |     pad_top      |
-        //  |******************|
-        //  |                  |
-        //  |      plane1      |
-        //  |      batch1      |
-        //  |__________________|
-        //  |     pad_bottom   |
-        //  |******************|
-        const int64_t max_offset = input_stride_z * input_depth - (input->info()->padding().bottom + input->info()->padding().top) * input_stride_y;
-        execute_window_loop(window_k, [&](const Coordinates & id_k) // loop on the batch size
-        {
-
-            execute_window_loop(window_out, [&](const Coordinates & id)
-            {
-                const auto y_offset = int(id.y() - conv_pad_left) * input_stride_y;
-
-                // Buffer pointers
-                const scalar_type *in_ptr      = reinterpret_cast<scalar_type *>(input->buffer() + input->info()->offset_first_element_in_bytes() + id[3] * input_stride_w);
-                const scalar_type *weights_ptr = reinterpret_cast<scalar_type *>(k.ptr());
-                uint8_t           *out_ptr     = out.ptr() + id_k[3] * output_stride_x;
-
-                // Output elements
-                vector_type out0 = wrapper::vdup_n(scalar_type(0), tag_type());
-                vector_type out1 = wrapper::vdup_n(scalar_type(0), tag_type());
-                vector_type out2 = wrapper::vdup_n(scalar_type(0), tag_type());
-                vector_type out3 = wrapper::vdup_n(scalar_type(0), tag_type());
-
-                // Reduce along the feature maps
-                for(int x = 0; x < input_width; x += num_elems_read_per_iteration)
-                {
-                    // z == 0
-                    auto in_z   = static_cast<int64_t>(id.z() * conv_stride_y - conv_pad_top);
-                    in_z        = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
-                    auto offset = y_offset + in_z * input_stride_z;
-                    offset      = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 0 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 1
-                    in_z   = static_cast<int64_t>(id.z() * conv_stride_y - conv_pad_top + 1);
-                    in_z   = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
-                    offset = y_offset + in_z * input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 1 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 2
-                    in_z   = static_cast<int64_t>(id.z() * conv_stride_y - conv_pad_top + 2);
-                    in_z   = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
-                    offset = y_offset + in_z * input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 2 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 3
-                    in_z   = static_cast<int64_t>(id.z() * conv_stride_y - conv_pad_top + 3);
-                    offset = y_offset + in_z * input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 3 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 4
-                    in_z   = static_cast<int64_t>(id.z() * conv_stride_y - conv_pad_top + 4);
-                    offset = y_offset + in_z * input_stride_z;
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 4 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 5
-                    offset += input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 5 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 6
-                    offset += input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 6 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 7
-                    offset += input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 7 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-
-                    // z == 8
-                    offset += input_stride_z;
-                    offset = std::min(offset, max_offset);
-                    convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 8 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
-                }
-
-                *(reinterpret_cast<scalar_type *>(out_ptr + 0 * output_stride_y)) = vreduce(out0);
-                *(reinterpret_cast<scalar_type *>(out_ptr + 1 * output_stride_y)) = vreduce(out1);
-                *(reinterpret_cast<scalar_type *>(out_ptr + 2 * output_stride_y)) = vreduce(out2);
-                *(reinterpret_cast<scalar_type *>(out_ptr + 3 * output_stride_y)) = vreduce(out3);
-            },
-            in, out);
-        },
-        k);
-    }
-};
-
 template <typename T1, typename T2>
 inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                          const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
@@ -1169,21 +820,6 @@
     }
 }
 
-template <typename V>
-inline void convolve_9x9_nhwc(const Window &window, unsigned int num_elems_read_per_iteration,
-                              const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
-{
-    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
-    switch(conv_stride_x)
-    {
-        case 1:
-            convolver_9x9_nhwc<V>::convolve(window, num_elems_read_per_iteration, input, weights, output, conv_info);
-            break;
-        default:
-            ARM_COMPUTE_ERROR("Not implemented");
-    }
-}
-
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
@@ -1337,69 +973,249 @@
     }
     else
     {
-        if(kernel_size == 9)
-        {
-            border_size.left = 0;
-            border_size.top  = conv_info.pad_left();
-
-            const int num_elems_read_per_iteration_x    = 4;
-            const int num_elems_written_per_iteration_x = 1;
-            const int num_elems_read_per_iteration_y    = 12;
-            const int num_elems_written_per_iteration_y = 4;
-
-            num_elems_read_per_iteration    = num_elems_read_per_iteration_x;
-            num_elems_written_per_iteration = num_elems_written_per_iteration_x;
-
-            border_size.right = num_elems_read_per_iteration_x;
-            if((conv_info.pad_bottom() != 0) || (conv_info.pad_top() != 0))
-            {
-                // If bottom or top padding are set, we need to read num_elems_read_per_iteration_y rows to zero.
-                // Since num_elems_read_per_iteration_y is always greater than conv_info.pad_right() we can set
-                // the bottom padding to num_elems_read_per_iteration_y
-                border_size.bottom = num_elems_read_per_iteration_y;
-            }
-            else if(conv_info.pad_right() != 0)
-            {
-                // Convetional border padding. Fill the bottom paddings so that we can read in batch of num_elems_read_per_iteration_y
-                border_size.bottom = ceil_to_multiple(input->dimension(1) + conv_info.pad_right(), num_elems_read_per_iteration_y) - input->dimension(1);
-            }
-            else
-            {
-                // No padding
-                border_size.bottom = 0;
-            }
-
-            win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y));
-
-            AccessWindowStatic input_access(input, 0, -border_size.top,
-                                            ceil_to_multiple(input->dimension(0), num_elems_read_per_iteration_x),
-                                            input->dimension(1) + border_size.bottom);
-
-            AccessWindowStatic    weights_access(weights, 0, 0, ceil_to_multiple(weights->dimension(0), num_elems_read_per_iteration_x), weights->dimension(1));
-            AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
-            window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
-            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
-        }
-        else
-        {
-            border_size.left             = 0;
-            border_size.top              = conv_info.pad_left();
-            border_size.right            = 0;
-            border_size.bottom           = conv_info.pad_right();
-            num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
-            win                          = calculate_max_window(*output, Steps());
-
-            AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
-            AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
-            window_changed = update_window_and_padding(win, input_access, weights_access);
-        }
+        // Configure window NHWC without any padding
+        win = calculate_max_window(*output, Steps());
+        Coordinates coord;
+        coord.set_num_dimensions(output->num_dimensions());
+        output->set_valid_region(ValidRegion(coord, output->tensor_shape()));
     }
 
     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
     return std::make_pair(err, win);
 }
+
+bool have_zero_x_internal_padding(const ITensorInfo *input, const ITensorInfo *weights) // True when neither input nor weights has left/right (X) padding
+{
+    return (input->padding().left == 0 && weights->padding().left == 0 && input->padding().right == 0 && weights->padding().right == 0);
+}
+
 } // namespace
 
+template <typename T>
+void NEDirectConvolutionLayerKernel::convolve_nhwc_optimized(const Window &window)
+{
+    // Assumes input and weights have no padding along the channel (X) dimension, so each WC row is contiguous
+
+    // 128-bit NEON vector type and tag for element type T
+    using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
+    using vector_type = typename vtype::type;
+    using tag_type    = typename vtype::tag_type;
+
+    // Scalar quantities (input/weights strides converted from bytes to elements)
+    const int element_size   = _input->info()->element_size();
+    const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size;
+    const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size;
+    const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size;
+    const int input_dim_w    = _input->info()->dimension(1);
+    const int input_dim_h    = _input->info()->dimension(2);
+
+    const int output_stride_c = _output->info()->strides_in_bytes().x(); // in bytes: applied to the uint8_t output pointer
+
+    const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size;
+    const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size;
+    const int          kernel_dim_w    = _weights->info()->dimension(1);
+    const int          kernel_dim_h    = _weights->info()->dimension(2);
+
+    const int conv_pad_top  = _conv_info.pad_top();
+    const int conv_pad_left = _conv_info.pad_left();
+    const int conv_stride_w = std::get<0>(_conv_info.stride());
+    const int conv_stride_h = std::get<1>(_conv_info.stride());
+
+    // Output window: collapse X (output channels) to one step; channels are covered by the weights loop
+    Window window_out = window;
+    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Weights window: collapse X/Y/Z so the iterator only steps over dimension 3 (one kernel per output channel)
+    Window window_w = calculate_max_window(*_weights->info(), Steps());
+    window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
+
+    Iterator out(_output, window_out);
+    Iterator wei(_weights, window_w);
+
+    constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
+    /*
+     * This implementation parallelizes over the full WC plane of input and weights by
+     * treating it as a flat series of elements. For example, with a 3x3 WC weight plane
+     * (W = 3, C = 3) and 4-element floating point vector operations, the first vector takes
+     * the 3 channel elements of the first W position plus the first channel element of the
+     * second one. The 9 elements of each WC weight plane therefore require two 4-element
+     * vector operations followed by a single scalar operation.
+     *
+     * This works because the input vector multiplied with the weights is loaded with
+     * exactly the required elements in the same order, so each input element is paired
+     * with its matching weight element.
+     */
+    execute_window_loop(window_out, [&](const Coordinates & id)
+    {
+        /*
+         * Compute the theoretical (unclamped) input indexes for this output point,
+         * then clamp them to the input borders and derive the matching weight indexes.
+         * This loop visits each output point in N, H and W; the output channel C is
+         * handled by the weights loop below.
+         */
+        // Theoretical input window for this output point (may extend into the padding)
+        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
+        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
+        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
+        const int in_h_end_t   = in_h_start_t + kernel_dim_h;
+
+        // Clamp to the valid input region; padded positions are skipped (equivalent to zero padding)
+        const int in_w_start = std::max(in_w_start_t, 0);
+        const int in_h_start = std::max(in_h_start_t, 0);
+        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
+        const int in_h_end   = std::min(in_h_end_t, input_dim_h);
+
+        // Matching weight range: W is folded into a flat WC offset (scaled by kernel_stride_w), H stays an index
+        const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w;
+        const int index_h_start  = in_h_start - in_h_start_t;
+        const int index_wc_end   = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w;
+        const int index_h_end    = kernel_dim_h - (in_h_end_t - in_h_end);
+
+        execute_window_loop(window_w, [&](const Coordinates & id_w)
+        {
+            /*
+             * Loop over the weights along dimension 3 (the kernel index). Each
+             * kernel produces one output channel, so id_w[3] selects the output
+             * channel written below.
+             */
+            const T *in_ptr_row = reinterpret_cast<const T *>(_input->buffer() + _input->info()->offset_first_element_in_bytes())
+                                  + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h;
+            const T *weights_ptr_row = reinterpret_cast<const T *>(wei.ptr()) + index_h_start * kernel_stride_h;
+            uint8_t *out_ptr         = out.ptr() + id_w[3] * output_stride_c;
+
+            T out_temp = static_cast<T>(0);
+            for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h)
+            {
+                const T    *in_ptr_mover = in_ptr_row;
+                int         index_wc     = index_wc_start;
+                vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
+                for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
+                {
+                    const auto src_vec = wrapper::vloadq(in_ptr_mover);
+                    const auto w_vec   = wrapper::vloadq(weights_ptr_row + index_wc);
+                    out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
+                }
+                out_temp += vreduce(out_temp_vec); // reduce the vector accumulator to a scalar
+                for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover) // leftover elements, one at a time
+                {
+                    const auto src_val = *(in_ptr_mover);
+                    const auto w_val   = *(weights_ptr_row + index_wc);
+                    out_temp += src_val * w_val;
+                }
+            }
+            *(reinterpret_cast<T *>(out_ptr)) = out_temp;
+        },
+        wei);
+    },
+    out);
+}
+
+template <typename T>
+void NEDirectConvolutionLayerKernel::convolve_nhwc(const Window &window)
+{
+    // Generic NHWC direct convolution: each output value is the dot product of one kernel with the input patch it overlaps. Only in-bounds input taps are visited (out-of-bounds taps are skipped, which is equivalent to zero padding), so no input border is required. Vector types used below:
+    using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>; // 128-bit NEON vector of T
+    using vector_type = typename vtype::type;
+    using tag_type    = typename vtype::tag_type;
+
+    // Input strides, converted from bytes to elements of T so they can be added to T pointers. Naming follows NHWC order: dim 1 = W, dim 2 = H, dim 3 = N
+    const int element_size   = _input->info()->element_size();
+    const int input_stride_w = _input->info()->strides_in_bytes().y() / element_size;
+    const int input_stride_h = _input->info()->strides_in_bytes().z() / element_size;
+    const int input_stride_n = _input->info()->strides_in_bytes()[3] / element_size;
+    const int input_dim_w    = _input->info()->dimension(1);
+    const int input_dim_h    = _input->info()->dimension(2);
+
+    const int output_stride_c = _output->info()->strides_in_bytes().x(); // kept in bytes: it is applied to the uint8_t output pointer below
+
+    const unsigned int kernel_stride_w = _weights->info()->strides_in_bytes().y() / element_size; // in elements of T; NOTE(review): divides by the input element size, so assumes weights share the input data type
+    const unsigned int kernel_stride_h = _weights->info()->strides_in_bytes().z() / element_size;
+    const int          kernel_dim_w    = _weights->info()->dimension(1);
+    const int          kernel_dim_h    = _weights->info()->dimension(2);
+
+    const int conv_pad_top  = _conv_info.pad_top();
+    const int conv_pad_left = _conv_info.pad_left();
+    const int conv_stride_w = std::get<0>(_conv_info.stride());
+    const int conv_stride_h = std::get<1>(_conv_info.stride());
+
+    // Output window: dimension X (channels) is collapsed to a single step, because all output channels of a point are written by the weights loop below
+    Window window_out = window;
+    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+    // Weights window: dimensions X, Y and Z are collapsed (they are walked manually below), so only the kernel index (dimension 3) is iterated
+    Window window_w = calculate_max_window(*_weights->info(), Steps());
+    window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
+    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
+
+    Iterator out(_output, window_out);
+    Iterator wei(_weights, window_w);
+
+    constexpr int num_elems_read_per_iteration = 16 / sizeof(T); // elements of T per 16-byte (128-bit) vector
+
+    execute_window_loop(window_out, [&](const Coordinates & id)
+    {
+        // Input window covered by the kernel for this output point, before clamping (may start or end inside the padding)
+        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
+        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
+        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
+        const int in_h_end_t   = in_h_start_t + kernel_dim_h;
+
+        // Clamp that window to the real input extent [0, input_dim)
+        const int in_w_start = std::max(in_w_start_t, 0);
+        const int in_h_start = std::max(in_h_start_t, 0);
+        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
+        const int in_h_end   = std::min(in_h_end_t, input_dim_h);
+
+        // Trim the kernel by the same amounts that were clamped off the input window, so input and weight indices stay aligned
+        const int wei_w_start = in_w_start - in_w_start_t;
+        const int wei_h_start = in_h_start - in_h_start_t;
+        const int wei_w_end   = kernel_dim_w - (in_w_end_t - in_w_end);
+        const int wei_h_end   = kernel_dim_h - (in_h_end_t - in_h_end);
+
+        const int      index_c_end  = _weights->info()->dimension(0); // channels reduced per tap; NOTE(review): loop-invariant, could be hoisted out of this lambda
+        const T *const in_ptr_start = reinterpret_cast<const T *>(_input->buffer() + _input->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n; // first input element of batch id[3]
+
+        execute_window_loop(window_w, [&](const Coordinates & id_w) // one iteration per kernel, i.e. per output channel id_w[3]
+        {
+            const T *const weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
+            uint8_t       *out_ptr           = out.ptr() + id_w[3] * output_stride_c; // output channel id_w[3] at this output point
+
+            T out_temp = static_cast<T>(0); // scalar accumulator for this output value
+            for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
+            {
+                const T *const in_ptr_row      = in_ptr_start + index_in_h * input_stride_h;
+                const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h;
+                for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
+                {
+                    const T    *in_ptr_mover      = in_ptr_row + index_in_w * input_stride_w;
+                    const T    *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
+                    int         index_c           = 0;
+                    vector_type out_temp_vec      = wrapper::vdup_n(static_cast<T>(0), tag_type()); // per-lane partial sums for this tap
+                    for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration)
+                    {
+                        const auto src_vec = wrapper::vloadq(in_ptr_mover);
+                        const auto w_vec   = wrapper::vloadq(weights_ptr_mover);
+                        out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
+                    }
+                    out_temp += vreduce(out_temp_vec); // fold in the lane partials (vreduce is defined outside this view; presumably a horizontal add)
+                    for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover) // scalar tail: channels left over after the last full vector
+                    {
+                        const auto src_val = *(in_ptr_mover);
+                        const auto w_val   = *(weights_ptr_mover);
+                        out_temp += src_val * w_val;
+                    }
+                }
+            }
+            *(reinterpret_cast<T *>(out_ptr)) = out_temp;
+        },
+        wei);
+    },
+    out);
+}
+
 NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
     : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
       _num_elems_written_per_iteration(0)
@@ -1425,7 +1241,14 @@
     const unsigned int conv_pad_top    = conv_info.pad_top();
     const unsigned int conv_pad_right  = conv_info.pad_right();
     const unsigned int conv_pad_bottom = conv_info.pad_bottom();
-    _border_size                       = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
+    if(_input->info()->data_layout() == DataLayout::NCHW)
+    {
+        _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
+    }
+    else
+    {
+        _border_size = BorderSize(0);
+    }
 
     // Get convolved dimensions
     TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
@@ -1536,22 +1359,17 @@
     }
     else
     {
-        const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
-        const int stride_x    = std::get<0>(_conv_info.stride());
-        const int stride_y    = std::get<1>(_conv_info.stride());
-
         switch(_input->info()->data_type())
         {
             case DataType::F32:
             {
-                if(kernel_size == 9 && stride_x == 1 && stride_y == 1)
+                if(have_zero_x_internal_padding(_input->info(), _weights->info()))
                 {
-                    using vtype = wrapper::traits::neon_vector<float, 4>;
-                    convolve_9x9_nhwc<vtype>(window, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
+                    convolve_nhwc_optimized<float>(window);
                 }
                 else
                 {
-                    convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
+                    convolve_nhwc<float>(window);
                 }
                 break;
             }
@@ -1561,3 +1379,4 @@
         }
     }
 }
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
index da7e771..fe54590 100644
--- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
@@ -28,14 +28,11 @@
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 
-#include <cmath>
-#include <tuple>
-
 namespace arm_compute
 {
 NEDirectConvolutionLayer::NEDirectConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false),
-      _is_activationlayer_enabled(false), _dim_split(Window::DimZ)
+      _is_activationlayer_enabled(false), _dim_split(Window::DimZ), _is_padding_required()
 {
 }
 
@@ -59,9 +56,13 @@
     {
         _output_stage_kernel.configure(output, bias);
     }
+    _is_padding_required = !_conv_kernel.border_size().empty();
 
-    // Add zero padding XY
-    _input_border_handler.configure(input, _conv_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<float>(0.f)));
+    if(_is_padding_required)
+    {
+        // Add zero padding XY
+        _input_border_handler.configure(input, _conv_kernel.border_size(), BorderMode::CONSTANT, PixelValue(static_cast<float>(0.f)));
+    }
 
     //Configure Activation Layer
     _is_activationlayer_enabled = act_info.enabled();
@@ -104,10 +105,12 @@
 
 void NEDirectConvolutionLayer::run()
 {
-    NEScheduler::get().schedule(&_input_border_handler, Window::DimZ);
-
     MemoryGroupResourceScope scope_mg(_memory_group);
 
+    if(_is_padding_required)
+    {
+        NEScheduler::get().schedule(&_input_border_handler, Window::DimZ);
+    }
     NEScheduler::get().schedule(&_conv_kernel, _dim_split);
     if(_has_bias)
     {