COMPMID-808 Add NHWC data format support for NEON direct convolution
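
Add a DataLayout::NHWC code path to NEDirectConvolutionLayerKernel and make
NEDirectConvolutionLayer pick its scheduling split dimension accordingly
(Window::DimZ for NCHW, Window::DimY for NHWC). Replace the local
output-shape helpers in the CL/NEON kernels and the test fixture with
misc::shape_calculator::compute_deep_convolution_shape(), and extend the
validation fixtures and test suites with a DataLayout dimension so the NEON
FP32 direct convolution is exercised in both NCHW and NHWC. Also swap the
fixed point position and quantization info in SimpleTensor's swap() and add
int8/int16 instantiations of the reference permute.

A minimal usage sketch of the new path follows. It is not part of this
patch: the helper name, shapes and padding are illustrative assumptions on
top of the existing Tensor / TensorInfo / NEDirectConvolutionLayer API, and
the bias is omitted. Note that the NHWC kernel currently supports F32 only
and that NHWC tensors are declared channel-first, i.e. (C, W, H [, N]).

    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    void direct_conv_nhwc_example() // hypothetical helper, for illustration only
    {
        Tensor src{}, weights{}, dst{};

        // Input: C=64, W=56, H=56. Weights: 128 kernels of 3x3x64, also in NHWC order.
        TensorInfo src_info(TensorShape(64U, 56U, 56U), 1, DataType::F32);
        TensorInfo wei_info(TensorShape(64U, 3U, 3U, 128U), 1, DataType::F32);
        src_info.set_data_layout(DataLayout::NHWC);
        wei_info.set_data_layout(DataLayout::NHWC);

        // Output shape as compute_deep_convolution_shape() would return it for
        // a 3x3 kernel with stride 1 and padding 1: (C=128, W=56, H=56).
        TensorInfo dst_info(TensorShape(128U, 56U, 56U), 1, DataType::F32);
        dst_info.set_data_layout(DataLayout::NHWC);

        src.allocator()->init(src_info);
        weights.allocator()->init(wei_info);
        dst.allocator()->init(dst_info);

        NEDirectConvolutionLayer conv;
        conv.configure(&src, &weights, nullptr /* no bias */, &dst,
                       PadStrideInfo(1, 1, 1, 1), ActivationLayerInfo());

        src.allocator()->allocate();
        weights.allocator()->allocate();
        dst.allocator()->allocate();

        // Fill src and weights here, then:
        conv.run(); // scheduled along Window::DimY for NHWC
    }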

Change-Id: I5d4cc3d5b0d25f3fe4ed998c0f15b1b8e260a43a
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125697
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 757e423..9543d98 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -300,19 +300,19 @@
     const size_t idx_height  = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::HEIGHT);
     const size_t idx_channel = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::CHANNEL);
 
-    const unsigned int input_width     = input_shape[idx_width];
-    const unsigned int input_height    = input_shape[idx_height];
-    const unsigned int weights_width   = weights_shape[idx_width];
-    const unsigned int weights_height  = weights_shape[idx_height];
-    const unsigned int weights_channel = weights_shape[idx_channel];
-    unsigned int       output_width    = 0;
-    unsigned int       output_height   = 0;
+    const unsigned int input_width         = input_shape[idx_width];
+    const unsigned int input_height        = input_shape[idx_height];
+    const unsigned int weights_width       = weights_shape[idx_width];
+    const unsigned int weights_height      = weights_shape[idx_height];
+    const unsigned int weights_out_channel = weights_shape[3];
+    unsigned int       output_width        = 0;
+    unsigned int       output_height       = 0;
     std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, weights_width, weights_height, conv_info);
 
     TensorShape output_shape{ input_shape };
     output_shape.set(idx_width, output_width);
     output_shape.set(idx_height, output_height);
-    output_shape.set(idx_channel, weights_channel);
+    output_shape.set(idx_channel, weights_out_channel);
 
     return output_shape;
 }
diff --git a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h
index 1eaad5c..ae384ff 100644
--- a/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h
@@ -106,6 +106,7 @@
     bool                                      _has_bias;
     bool                                      _is_fixed_point;
     bool                                      _is_activationlayer_enabled;
+    unsigned int                              _dim_split;
 };
 }
 #endif /* __ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYER_H__ */
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 13ee9a1..e1fa650 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -34,6 +34,7 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "support/ToolchainSupport.h"
 
@@ -41,26 +42,6 @@
 
 namespace
 {
-/** Calculates expected output shape dimension
- *
- * @param[in] Input shape
- *
- * @return Expected output shape
- */
-TensorShape get_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo conv_info)
-{
-    unsigned int output_width  = 0;
-    unsigned int output_height = 0;
-    std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), conv_info);
-
-    TensorShape output_shape = input_shape;
-    output_shape.set(0, output_width);
-    output_shape.set(1, output_height);
-    output_shape.set(2, weights_shape[3]);
-
-    return output_shape;
-}
-
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
@@ -100,7 +81,7 @@
     if(output->total_size() != 0)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(),
-                                                           get_output_shape(input->tensor_shape(), weights->tensor_shape(), conv_info));
+                                                           misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
     }
@@ -114,7 +95,7 @@
     const DataType     data_type   = input->data_type();
 
     // Get convolved dimensions
-    TensorShape output_shape = get_output_shape(input->tensor_shape(), weights->tensor_shape(), conv_info);
+    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
 
     // Output auto inizialitation if not yet initialized
     // FIXME: input->clone()->set_tensor_shape(output_shape) doesn't work with subtensors for grouped direct convolutions (AlexNet).
@@ -134,7 +115,8 @@
     unsigned int num_elems_written_per_iteration_x = 0;
     unsigned int num_elems_written_per_iteration_y = 0;
 
-    if(gpu_target_is_in(target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (conv_stride_x == 1) && (conv_stride_y == 1) && (data_type == DataType::F32))
+    if(gpu_target_is_in(target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (conv_stride_x == 1)
+       && (conv_stride_y == 1) && (data_type == DataType::F32))
     {
         // Configure kernel window
 
@@ -274,7 +256,7 @@
     const DataType     data_type   = input->info()->data_type();
 
     // Get convolved dimensions
-    TensorShape output_shape = get_output_shape(input->info()->tensor_shape(), weights->info()->tensor_shape(), conv_info);
+    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
 
     // Output auto inizialitation if not yet initialized
     // FIXME: input->clone()->set_tensor_shape(output_shape) doesn't work with subtensors for grouped direct convolutions (AlexNet).
@@ -309,7 +291,8 @@
     CLBuildOptions build_options;
     build_options.add_option_if(_biases != nullptr, std::string("-DHAS_BIAS"));
 
-    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (_conv_stride_x == 1) && (_conv_stride_y == 1) && (data_type == DataType::F32))
+    if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && (kernel_size <= 5) && (_conv_stride_x == 1)
+       && (_conv_stride_y == 1) && (data_type == DataType::F32))
     {
         build_options.add_option(std::string("-DWEIGHTS_DEPTH=" + support::cpp11::to_string(_weights->info()->dimension(2))));
 
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 285ec2d..5eafdf0 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -33,6 +33,7 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 
 #include <algorithm>
 #include <arm_neon.h>
@@ -663,6 +664,118 @@
     vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
 }
 
+template <typename T1>
+class convolver_nhwc
+{
+public:
+    static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
+                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
+    {
+        const int          input_width     = input->info()->dimension(0);
+        const int          input_depth     = input->info()->dimension(2);
+        const int          input_stride_x  = input->info()->strides_in_bytes().x();
+        const int          input_stride_y  = input->info()->strides_in_bytes().y();
+        const int          input_stride_z  = input->info()->strides_in_bytes().z();
+        const int          output_stride_x = output->info()->strides_in_bytes().x();
+        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
+        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
+        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
+        const int          conv_pad_top    = conv_info.pad_top();
+        const unsigned int conv_stride_x   = std::get<0>(conv_info.stride());
+        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
+        const T1           zero            = 0;
+
+        // Setup input window for the input iterator
+        Window window_in = window;
+        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
+        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+        // Setup output window for the output iterator
+        Window window_out = window;
+        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+        // Setup weights window for the weights iterator
+        Window window_k = calculate_max_window(*weights->info(), Steps());
+        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
+        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
+        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
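+        // Dimension 3 of the weights indexes the output feature map; each iteration of
+        // window_k therefore produces one output channel (dimension 0 of the NHWC output).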
+        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
+
+        Iterator in(input, window_in);
+        Iterator out(output, window_out);
+        Iterator k(weights, window_k);
+
+        execute_window_loop(window_k, [&](const Coordinates & id_k)
+        {
+            execute_window_loop(window_out, [&](const Coordinates & id)
+            {
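+                // The output window's y/z dimensions run over the spatial width/height in NHWC,
+                // so map them back to input coordinates using the x/y strides and left/top padding.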
+                const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
+                const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
+
+                const uint8_t *in_ptr  = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
+                uint8_t       *out_ptr = out.ptr() + id_k[3] * output_stride_x;
+
+                T1 out_val = 0;
+
+                auto in_addr_base0 = in_ptr;
+                auto we_addr_base0 = k.ptr();
+
+                for(int z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
+                {
+                    const int in_z = id.z() * conv_stride_y + z - conv_pad_top;
+
+                    if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
+                    {
+                        auto in_addr_base1 = in_addr_base0;
+                        auto we_addr_base1 = we_addr_base0;
+
+                        for(int y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
+                        {
+                            auto out_values = internal_vdupq_n(zero);
+
+                            int x           = 0;
+                            int no_leftover = input_width - num_elems_read_per_iteration;
+
+                            for(; x < no_leftover; x += num_elems_read_per_iteration)
+                            {
+                                const auto in_addr   = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
+                                const auto in_values = internal_vld1q<1>(in_addr);
+
+                                const auto we_addr   = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
+                                const auto we_values = internal_vld1q<1>(we_addr);
+
+                                out_values = internal_vmlal(out_values, in_values, we_values, 0);
+                            }
+
+                            out_val += out_values[0];
+                            out_val += out_values[1];
+                            out_val += out_values[2];
+                            out_val += out_values[3];
+
+                            // Leftover
+                            for(; x < input_width; ++x)
+                            {
+                                const auto in_addr  = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
+                                const auto in_value = *(in_addr);
+
+                                const auto we_addr  = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
+                                const auto we_value = *(we_addr);
+
+                                out_val += in_value * we_value;
+                            }
+                        }
+                    }
+                }
+
+                *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
+            },
+            in, out);
+        },
+        k);
+    }
+};
+
 template <typename T1, typename T2, unsigned int stridex>
 class convolver_3x3
 {
@@ -1003,35 +1116,28 @@
     }
 }
 
-inline TensorShape get_convolved_dimensions(const ITensorInfo *input, const ITensorInfo *weights, const int kernel_size, const PadStrideInfo &conv_info)
-{
-    unsigned int output_width  = 0;
-    unsigned int output_height = 0;
-    std::tie(output_width, output_height) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_size, kernel_size, conv_info);
-
-    TensorShape output_shape = input->tensor_shape();
-    output_shape.set(0, output_width);
-    output_shape.set(1, output_height);
-    output_shape.set(2, weights->dimension(3));
-
-    return output_shape;
-}
-
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
 
+    const DataLayout data_layout = input->data_layout();
+    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2));
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->dimension(1));
+    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
+    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
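+    // The NHWC path currently supports F32 only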
+    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
 
     // Checks performed when output is configured
     if(output->total_size() != 0)
     {
-        TensorShape output_shape = get_convolved_dimensions(input, weights, weights->dimension(0), conv_info);
+        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
 
         DataType data_type = input->data_type();
         if(is_data_type_fixed_point(data_type))
@@ -1050,101 +1156,127 @@
 std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                         unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
 {
+    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+
+    const DataLayout data_layout = input->data_layout();
+    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+
     // Calculate right and bottom border
-    unsigned int kernel_size   = weights->dimension(0);
+    unsigned int kernel_size   = weights->dimension(width_idx);
     const int    conv_stride_x = std::get<0>(conv_info.stride());
     const int    conv_stride_y = std::get<1>(conv_info.stride());
-    const int    input_width   = input->dimension(0);
+    const int    input_width   = input->dimension(width_idx);
 
-    switch(kernel_size)
+    Window win{};
+    bool   window_changed = false;
+
+    if(data_layout == DataLayout::NCHW)
     {
-        case 1:
+        switch(kernel_size)
         {
-            switch(input->data_type())
+            case 1:
             {
+                switch(input->data_type())
+                {
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                case DataType::F16:
+                    case DataType::F16:
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-                case DataType::QS8:
-                case DataType::QS16:
-                    num_elems_written_per_iteration = 8;
-                    break;
-                case DataType::F32:
-                    if(run_optim_small_tensor_info(input))
-                    {
+                    case DataType::QS8:
+                    case DataType::QS16:
                         num_elems_written_per_iteration = 8;
-                    }
-                    else
-                    {
-                        num_elems_written_per_iteration = 4;
-                    }
-                    break;
-                default:
-                    ARM_COMPUTE_ERROR("Data type not supported.");
-                    break;
+                        break;
+                    case DataType::F32:
+                        if(run_optim_small_tensor_info(input))
+                        {
+                            num_elems_written_per_iteration = 8;
+                        }
+                        else
+                        {
+                            num_elems_written_per_iteration = 4;
+                        }
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Data type not supported.");
+                        break;
+                }
+                num_weight_elems_read_per_row = kernel_size;
+                num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
+                break;
             }
-            num_weight_elems_read_per_row = kernel_size;
-            num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
-            break;
-        }
-        case 3:
-        case 5:
-        {
-            switch(input->data_type())
+            case 3:
+            case 5:
             {
-                case DataType::F32:
-                    num_weight_elems_read_per_row   = 4 + kernel_size - 1;
-                    num_elems_read_per_iteration    = 12;
-                    num_elems_written_per_iteration = 16 >> conv_stride_x;
-                    break;
+                switch(input->data_type())
+                {
+                    case DataType::F32:
+                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
+                        num_elems_read_per_iteration    = 12;
+                        num_elems_written_per_iteration = 16 >> conv_stride_x;
+                        break;
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                case DataType::F16:
+                    case DataType::F16:
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-                case DataType::QS8:
-                case DataType::QS16:
-                    num_weight_elems_read_per_row   = 8 + kernel_size - 1;
-                    num_elems_read_per_iteration    = 24;
-                    num_elems_written_per_iteration = 32 >> conv_stride_x;
-                    break;
-                default:
-                    ARM_COMPUTE_ERROR("Data type not supported.");
-                    break;
+                    case DataType::QS8:
+                    case DataType::QS16:
+                        num_weight_elems_read_per_row   = 8 + kernel_size - 1;
+                        num_elems_read_per_iteration    = 24;
+                        num_elems_written_per_iteration = 32 >> conv_stride_x;
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Data type not supported.");
+                        break;
+                }
+            }
+            break;
+            default:
+            {
+                ARM_COMPUTE_ERROR("Not implemented");
+                break;
             }
         }
-        break;
-        default:
-        {
-            ARM_COMPUTE_ERROR("Not implemented");
-            break;
-        }
+
+        // Calculate right pad
+        int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
+        int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
+        int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
+
+        // Calculate border
+        const unsigned int conv_pad_left   = conv_info.pad_left();
+        const unsigned int conv_pad_top    = conv_info.pad_top();
+        const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
+        const unsigned int conv_pad_bottom = conv_info.pad_bottom();
+
+        border_size.left   = conv_pad_left;
+        border_size.top    = conv_pad_top;
+        border_size.right  = conv_pad_right;
+        border_size.bottom = conv_pad_bottom;
+
+        // Configure window
+        win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
+
+        AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
+                                           num_elems_read_per_iteration, kernel_size,
+                                           conv_stride_x, conv_stride_y);
+        AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
+        AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
+        window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
+        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
     }
+    else
+    {
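+        // NHWC: the image width runs along the window's y dimension, so the left/right
+        // convolution padding becomes the top/bottom of the border; top/bottom image padding
+        // is handled inside the kernel by bounds checks on the height coordinate.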
+        border_size.left   = 0;
+        border_size.top    = conv_info.pad_left();
+        border_size.right  = 0;
+        border_size.bottom = conv_info.pad_right();
 
-    // Calculate right pad
-    int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
-    int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
-    int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
+        num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
 
-    // Calculate border
-    const unsigned int conv_pad_left   = conv_info.pad_left();
-    const unsigned int conv_pad_top    = conv_info.pad_top();
-    const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
-    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
+        win = calculate_max_window(*output, Steps());
 
-    border_size.left   = conv_pad_left;
-    border_size.top    = conv_pad_top;
-    border_size.right  = conv_pad_right;
-    border_size.bottom = conv_pad_bottom;
-
-    // Configure window
-    Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
-
-    AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
-                                       num_elems_read_per_iteration, kernel_size,
-                                       conv_stride_x, conv_stride_y);
-    AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
-    AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
-    bool                   window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
-    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+        AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
+        AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
+        window_changed = update_window_and_padding(win, input_access, weights_access);
+    }
 
     Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
     return std::make_pair(err, win);
@@ -1170,7 +1302,7 @@
     _weights     = weights;
     _output      = output;
     _conv_info   = conv_info;
-    _kernel_size = weights->info()->dimension(0);
+    _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
 
     const unsigned int conv_pad_left   = conv_info.pad_left();
     const unsigned int conv_pad_top    = conv_info.pad_top();
@@ -1179,7 +1311,7 @@
     _border_size                       = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
 
     // Get convolved dimensions
-    TensorShape output_shape = get_convolved_dimensions(input->info(), weights->info(), _kernel_size, conv_info);
+    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
 
     DataType data_type = input->info()->data_type();
 
@@ -1229,73 +1361,88 @@
     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
     ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
 
-    const int kernel_size = _weights->info()->dimension(0);
+    const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
 
-    switch(kernel_size)
+    if(_input->info()->data_layout() == DataLayout::NCHW)
     {
-        case 1:
+        switch(kernel_size)
         {
-            switch(_input->info()->data_type())
+            case 1:
             {
-                case DataType::QS8:
-                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
-                case DataType::QS16:
-                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
-                case DataType::F32:
-                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
+                switch(_input->info()->data_type())
+                {
+                    case DataType::QS8:
+                        convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
+                    case DataType::QS16:
+                        convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
+                    case DataType::F32:
+                        convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                case DataType::F16:
-                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
+                    case DataType::F16:
+                        convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-                default:
-                    ARM_COMPUTE_ERROR("Data type not supported");
-                    break;
+                    default:
+                        ARM_COMPUTE_ERROR("Data type not supported");
+                        break;
+                }
+                break;
             }
-            break;
-        }
-        case 3:
-        {
-            switch(_input->info()->data_type())
+            case 3:
             {
-                case DataType::QS8:
-                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
-                case DataType::F32:
-                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
+                switch(_input->info()->data_type())
+                {
+                    case DataType::QS8:
+                        convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
+                    case DataType::F32:
+                        convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-                case DataType::F16:
-                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
+                    case DataType::F16:
+                        convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-                default:
-                    ARM_COMPUTE_ERROR("Data type not supported");
-                    break;
+                    default:
+                        ARM_COMPUTE_ERROR("Data type not supported");
+                        break;
+                }
+                break;
             }
-            break;
-        }
-        case 5:
-        {
-            switch(_input->info()->data_type())
+            case 5:
             {
-                case DataType::F32:
-                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
-                    break;
-                default:
-                    ARM_COMPUTE_ERROR("Data type not supported");
-                    break;
+                switch(_input->info()->data_type())
+                {
+                    case DataType::F32:
+                        convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Data type not supported");
+                        break;
+                }
+                break;
             }
-            break;
-        }
 
-        default:
+            default:
+            {
+                ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
+                break;
+            }
+        }
+    }
+    else
+    {
+        switch(_input->info()->data_type())
         {
-            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
-            break;
+            case DataType::F32:
+                convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
+                break;
+            default:
+                ARM_COMPUTE_ERROR("Data type not supported");
+                break;
         }
     }
 }
diff --git a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
index 00776d7..445864c 100644
--- a/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDirectConvolutionLayer.cpp
@@ -35,18 +35,22 @@
 
 NEDirectConvolutionLayer::NEDirectConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
     : _memory_group(std::move(memory_manager)), _output_stage_kernel(), _conv_kernel(), _input_border_handler(), _activationlayer_function(), _accumulator(), _has_bias(false), _is_fixed_point(false),
-      _is_activationlayer_enabled(false)
+      _is_activationlayer_enabled(false), _dim_split(Window::DimZ)
 {
 }
 
 void NEDirectConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *bias, ITensor *output, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
+
     // Free accumulator
     if(_accumulator.buffer() != nullptr)
     {
         _accumulator.allocator()->free();
     }
 
+    _dim_split = input->info()->data_layout() == DataLayout::NCHW ? Window::DimZ : Window::DimY;
+
     // Check if bias should be added in the convolution result
     _has_bias = (bias != nullptr);
 
@@ -124,7 +128,7 @@
 
     _memory_group.acquire();
 
-    NEScheduler::get().schedule(&_conv_kernel, Window::DimZ);
+    NEScheduler::get().schedule(&_conv_kernel, _dim_split);
     if(_has_bias || _is_fixed_point)
     {
         NEScheduler::get().schedule(&_output_stage_kernel, Window::DimY);
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index 5a55a95..cfd1383 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -426,6 +426,8 @@
     swap(tensor1._format, tensor2._format);
     swap(tensor1._data_type, tensor2._data_type);
     swap(tensor1._num_channels, tensor2._num_channels);
+    swap(tensor1._fixed_point_position, tensor2._fixed_point_position);
+    swap(tensor1._quantization_info, tensor2._quantization_info);
     swap(tensor1._buffer, tensor2._buffer);
 }
 } // namespace test
diff --git a/tests/validation/CL/DirectConvolutionLayer.cpp b/tests/validation/CL/DirectConvolutionLayer.cpp
index 4564c64..19d914e 100644
--- a/tests/validation/CL/DirectConvolutionLayer.cpp
+++ b/tests/validation/CL/DirectConvolutionLayer.cpp
@@ -171,7 +171,9 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::F16)), ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                ActivationFunctionsDataset),
+                                                                                                        framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_fp16, tolerance_num);
@@ -179,8 +181,9 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::F32)),
-                                                                                                         ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, CLDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                 ActivationFunctionsDataset),
+                                                                                                         framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_fp32);
diff --git a/tests/validation/GLES_COMPUTE/DirectConvolutionLayer.cpp b/tests/validation/GLES_COMPUTE/DirectConvolutionLayer.cpp
index 47c35f9..0942b07 100644
--- a/tests/validation/GLES_COMPUTE/DirectConvolutionLayer.cpp
+++ b/tests/validation/GLES_COMPUTE/DirectConvolutionLayer.cpp
@@ -87,8 +87,9 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<half_float::half>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::F16)),
-                                                                                                                    ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<half_float::half>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                    ActivationFunctionsDataset),
+                                                                                                                    framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_fp16, tolerance_num);
@@ -96,8 +97,9 @@
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(data, framework::dataset::make("DataType", DataType::F32)),
-                                                                                                         ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                 ActivationFunctionsDataset),
+                                                                                                         framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_fp32);
diff --git a/tests/validation/NEON/DirectConvolutionLayer.cpp b/tests/validation/NEON/DirectConvolutionLayer.cpp
index b6f9f62..f700758 100644
--- a/tests/validation/NEON/DirectConvolutionLayer.cpp
+++ b/tests/validation/NEON/DirectConvolutionLayer.cpp
@@ -49,39 +49,39 @@
 constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance for floating point tests */
 
 /** Direct convolution data set. */
-const auto data_pad_f32 = concat(concat(combine(framework::dataset::make("PadX", 0, 1),
-                                                combine(framework::dataset::make("PadY", 0, 1),
-                                                        framework::dataset::make("KernelSize", 1))),
-                                        combine(framework::dataset::make("PadX", 0, 2),
-                                                combine(framework::dataset::make("PadY", 0, 2),
+const auto data_pad_f32 = concat(concat(combine(framework::dataset::make("PadX", { 0, 1 }),
+                                                combine(framework::dataset::make("PadY", { 0, 1 }),
+                                                        framework::dataset::make("KernelSize", 3))),
+                                        combine(framework::dataset::make("PadX", { 0, 2 }),
+                                                combine(framework::dataset::make("PadY", { 0, 2 }),
                                                         framework::dataset::make("KernelSize", 3)))),
-                                 combine(framework::dataset::make("PadX", 0, 3),
-                                         combine(framework::dataset::make("PadY", 0, 3),
+                                 combine(framework::dataset::make("PadX", { 0, 3 }),
+                                         combine(framework::dataset::make("PadY", { 0, 3 }),
                                                  framework::dataset::make("KernelSize", 5))));
 
 const auto data_pad_qs8 = concat(combine(framework::dataset::make("PadX", 0),
                                          combine(framework::dataset::make("PadY", 0),
                                                  framework::dataset::make("KernelSize", 1))),
-                                 combine(framework::dataset::make("PadX", 0, 2),
-                                         combine(framework::dataset::make("PadY", 0, 2),
+                                 combine(framework::dataset::make("PadX", { 0, 2 }),
+                                         combine(framework::dataset::make("PadY", { 0, 2 }),
                                                  framework::dataset::make("KernelSize", 3))));
 
 const auto data_f32 = combine(datasets::SmallDirectConvolutionShapes(),
-                              combine(framework::dataset::make("StrideX", 1, 3),
-                                      combine(framework::dataset::make("StrideY", 1, 3),
+                              combine(framework::dataset::make("StrideX", { 1, 3 }),
+                                      combine(framework::dataset::make("StrideY", { 1, 3 }),
                                               combine(data_pad_f32,
                                                       framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
 
 const auto data_qs8 = combine(datasets::TinyDirectConvolutionShapes(),
-                              combine(framework::dataset::make("StrideX", 1, 3),
-                                      combine(framework::dataset::make("StrideY", 1, 3),
+                              combine(framework::dataset::make("StrideX", { 1, 3 }),
+                                      combine(framework::dataset::make("StrideY", { 1, 3 }),
                                               combine(data_pad_qs8,
                                                       framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
 
 /** Direct convolution QS16 data set. */
 const auto data_qs16 = combine(datasets::TinyDirectConvolutionShapes(),
-                               combine(framework::dataset::make("StrideX", 1, 3),
-                                       combine(framework::dataset::make("StrideY", 1, 3),
+                               combine(framework::dataset::make("StrideX", { 1, 3 }),
+                                       combine(framework::dataset::make("StrideY", { 1, 3 }),
                                                combine(framework::dataset::make("PadX", 0),
                                                        combine(framework::dataset::make("PadY", 0),
                                                                combine(framework::dataset::make("KernelSize", 1),
@@ -174,8 +174,9 @@
 TEST_SUITE(Float)
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(data_f32, framework::dataset::make("DataType", DataType::F16)),
-                                                                                                        ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(data_f32, framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                ActivationFunctionsDataset),
+                                                                                                        framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_fp16);
@@ -184,8 +185,9 @@
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(data_f32, framework::dataset::make("DataType", DataType::F32)),
-                                                                                                         ActivationFunctionsDataset))
+FIXTURE_DATA_TEST_CASE(Run, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(data_f32, framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                 ActivationFunctionsDataset),
+                                                                                                         framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
 {
     // Validate output
     validate(Accessor(_target), _reference, tolerance_fp32);
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
index ef7721d..9ea4061 100644
--- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
@@ -21,8 +21,10 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+#include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "tests/AssetsLibrary.h"
 #include "tests/Globals.h"
 #include "tests/IAccessor.h"
@@ -31,6 +33,7 @@
 #include "tests/validation/Helpers.h"
 #include "tests/validation/fixtures/ConvolutionLayerFixture.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/reference/Permute.h"
 
 #include <random>
 
@@ -40,6 +43,8 @@
 {
 namespace validation
 {
+using namespace arm_compute::misc::shape_calculator;
+
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
 class DirectConvolutionValidationGenericFixture : public framework::Fixture
 {
@@ -49,26 +54,42 @@
 public:
     template <typename...>
     void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels,
-               DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
+               DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataLayout data_layout)
     {
+        ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+
         _fractional_bits   = fractional_bits;
         _quantization_info = quantization_info;
         _data_type         = data_type;
 
-        const TensorShape   weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels);
+        TensorShape         weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels);
         const TensorShape   bias_shape(num_kernels);
         const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR);
-        const TensorShape   output_shape   = get_output_shape(input_shape, weights_shape, info);
         const DataType      bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
 
-        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info);
-        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info);
+        if(data_layout == DataLayout::NHWC)
+        {
+            permute(input_shape, PermutationVector(2U, 0U, 1U));
+            permute(weights_shape, PermutationVector(2U, 0U, 1U));
+        }
+
+        TensorInfo input_info   = TensorInfo(input_shape, 1, data_type, _fractional_bits);
+        TensorInfo weights_info = TensorInfo(weights_shape, 1, data_type, _fractional_bits);
+
+        input_info.set_data_layout(data_layout);
+        weights_info.set_data_layout(data_layout);
+
+        const TensorShape output_shape = compute_deep_convolution_shape(input_info, weights_info, info);
+
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info, data_layout);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info, data_layout);
     }
 
     template <typename...>
     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation,
-               DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
+               DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataLayout data_layout)
     {
+        ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
         ARM_COMPUTE_UNUSED(dilation);
 
         _fractional_bits   = fractional_bits;
@@ -77,8 +98,15 @@
 
         const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
 
-        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info);
-        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info);
+        if(data_layout == DataLayout::NHWC)
+        {
+            permute(input_shape, PermutationVector(2U, 0U, 1U));
+            permute(weights_shape, PermutationVector(2U, 0U, 1U));
+            permute(output_shape, PermutationVector(2U, 0U, 1U));
+        }
+
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info, data_layout);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, bias_data_type, fractional_bits, quantization_info, act_info, data_layout);
     }
 
 protected:
@@ -112,13 +140,13 @@
     }
 
     TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
-                              DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
+                              DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info, const DataLayout &data_layout)
     {
         // Create tensors
-        TensorType src     = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info);
-        TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info);
+        TensorType src     = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position, quantization_info, data_layout);
+        TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, quantization_info, data_layout);
         TensorType bias    = create_tensor<TensorType>(bias_shape, bias_data_type, 1, fixed_point_position, quantization_info);
-        TensorType dst     = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info);
+        TensorType dst     = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position, quantization_info, data_layout);
 
         // Create and configure function
         FunctionType conv;
@@ -152,11 +180,13 @@
     }
 
     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
-                                      DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
+                                      DataType data_type, DataType bias_data_type, int fixed_point_position, QuantizationInfo quantization_info, ActivationLayerInfo act_info, const DataLayout &data_layout)
     {
+        ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
+
         // Create reference
-        SimpleTensor<T>     src{ input_shape, data_type, 1, fixed_point_position, quantization_info };
-        SimpleTensor<T>     weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info };
+        SimpleTensor<T>     src{ input_shape, data_type, 1, fixed_point_position, quantization_info, data_layout };
+        SimpleTensor<T>     weights{ weights_shape, data_type, 1, fixed_point_position, quantization_info, data_layout };
         SimpleTensor<TBias> bias{ bias_shape, bias_data_type, 1, fixed_point_position, quantization_info };
 
         // Fill reference
@@ -164,9 +194,25 @@
         fill(weights, 1);
         fill(bias, 2);
 
-        return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info),
-                                                                     act_info) :
-               reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+        SimpleTensor<T> dst;
+
+        // FIXME: move to reference once all functions that call reference::convolution_layer<>() support NHWC
+        if(src.data_layout() == DataLayout::NHWC)
+        {
+            SimpleTensor<T> src_nchw     = reference::permute<T>(src, PermutationVector(1U, 2U, 0U));
+            SimpleTensor<T> weights_nchw = reference::permute<T>(weights, PermutationVector(1U, 2U, 0U));
+
+            TensorShape output_shape_nchw{ output_shape };
+            permute(output_shape_nchw, PermutationVector(1U, 2U, 0U));
+
+            dst = reference::permute<T>(reference::convolution_layer<T>(src_nchw, weights_nchw, bias, output_shape_nchw, info), PermutationVector(2U, 0U, 1U));
+        }
+        else
+        {
+            dst = reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+        }
+
+        return (act_info.enabled()) ? reference::activation_layer<T>(dst, act_info) : dst;
     }
 
     TensorType       _target{};
@@ -174,21 +220,6 @@
     int              _fractional_bits{};
     QuantizationInfo _quantization_info{};
     DataType         _data_type{};
-
-private:
-    TensorShape get_output_shape(TensorShape in_shape, TensorShape kernel_shape, const PadStrideInfo &info)
-    {
-        TensorShape out_shape(in_shape);
-        const std::pair<unsigned int, unsigned int> scaled_dims = scaled_dimensions(in_shape.x(),
-                                                                                    in_shape.y(),
-                                                                                    kernel_shape.x(),
-                                                                                    kernel_shape.y(),
-                                                                                    info);
-        out_shape.set(0, scaled_dims.first);
-        out_shape.set(1, scaled_dims.second);
-        out_shape.set(2, kernel_shape[3]);
-        return out_shape;
-    }
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -196,10 +227,11 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, ActivationLayerInfo act_info)
+    void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels, DataType data_type, ActivationLayerInfo act_info,
+               DataLayout data_layout)
     {
         DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, QuantizationInfo(),
-                                                                                                    act_info);
+                                                                                                    act_info, data_layout);
     }
 };
 
@@ -212,7 +244,7 @@
                ActivationLayerInfo act_info)
     {
         DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, fractional_bits,
-                                                                                                    QuantizationInfo(), act_info);
+                                                                                                    QuantizationInfo(), act_info, DataLayout::NCHW);
     }
 };
 
@@ -225,7 +257,7 @@
                ActivationLayerInfo act_info)
     {
         DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, 0, quantization_info,
-                                                                                                    act_info);
+                                                                                                    act_info, DataLayout::NCHW);
     }
 };
 
@@ -238,7 +270,7 @@
                DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
     {
         DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, quantization_info,
-                                                                                                    act_info);
+                                                                                                    act_info, DataLayout::NCHW);
     }
 };
 
@@ -251,7 +283,7 @@
                DataType data_type, ActivationLayerInfo act_info)
     {
         DirectConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, data_type, 0, QuantizationInfo(),
-                                                                                                    act_info);
+                                                                                                    act_info, DataLayout::NCHW);
     }
 };
 
diff --git a/tests/validation/reference/Permute.cpp b/tests/validation/reference/Permute.cpp
index db347e5..c670c3e 100644
--- a/tests/validation/reference/Permute.cpp
+++ b/tests/validation/reference/Permute.cpp
@@ -60,6 +60,8 @@
 template SimpleTensor<uint8_t> permute(const SimpleTensor<uint8_t> &src, PermutationVector perm);
 template SimpleTensor<uint16_t> permute(const SimpleTensor<uint16_t> &src, PermutationVector perm);
 template SimpleTensor<uint32_t> permute(const SimpleTensor<uint32_t> &src, PermutationVector perm);
+template SimpleTensor<int8_t> permute(const SimpleTensor<int8_t> &src, PermutationVector perm);
+template SimpleTensor<int16_t> permute(const SimpleTensor<int16_t> &src, PermutationVector perm);
 template SimpleTensor<float> permute(const SimpleTensor<float> &src, PermutationVector perm);
 template SimpleTensor<half> permute(const SimpleTensor<half> &src, PermutationVector perm);
 } // namespace reference