COMPMID-1277 - Optimizing CLIm2ColKernel for NHWC.

This patch includes:

- Im2Col optimizations for NHWC using a new data layout
- Refactoring of CLIm2ColKernel, adding a validation method and auto-initialization
- Removal of im2col_reduced from CLIm2ColKernel and creation of a new CLFlattenLayerKernel

Change-Id: I1620640b6796baa268324b33ae92cdd8de53e27c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141241
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
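
As a rough illustration of the NHWC "channel-first" im2col layout exercised by the tests below, here is a minimal standalone sketch (not part of the patch: it uses plain std::vector buffers and toy sizes instead of the library's SimpleTensor) that reproduces the element ordering computed by the new im2col_nhwc_channel_first reference:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Standalone sketch of the NHWC channel-first im2col ordering (toy sizes, no bias).
    int main()
    {
        const int C = 2, W = 4, H = 4;                 // input: 2 channels, 4x4 spatial, single batch
        const int KW = 3, KH = 3, stride = 1, pad = 1; // 3x3 kernel, stride 1, same padding
        const int out_w = (W + 2 * pad - KW) / stride + 1;
        const int out_h = (H + 2 * pad - KH) / stride + 1;

        // NHWC source buffer: element (c, x, y) lives at c + x * C + y * C * W
        std::vector<float> src(C * W * H);
        for(std::size_t i = 0; i < src.size(); ++i)
        {
            src[i] = static_cast<float>(i);
        }

        // One im2col row per output location; within a row, the C channel values of each
        // of the KW * KH sampled pixels are stored contiguously (hence "channel first").
        const int          row_len = C * KW * KH;
        std::vector<float> dst(out_w * out_h * row_len, 0.f);

        for(int yo = 0; yo < out_w * out_h; ++yo)
        {
            const int xi = (yo % out_w) * stride; // top-left corner of the sampled patch
            const int yi = (yo / out_w) * stride;
            for(int ci = 0; ci < C; ++ci)
            {
                for(int yk = 0; yk < KH; ++yk)
                {
                    for(int xk = 0; xk < KW; ++xk)
                    {
                        const int  x      = xi + xk - pad;
                        const int  y      = yi + yk - pad;
                        const bool inside = (x >= 0) && (x < W) && (y >= 0) && (y < H);
                        // Destination index mirrors the new reference: ci + (xk + yk * KW) * C + yo * row_len
                        dst[ci + (xk + yk * KW) * C + yo * row_len] = inside ? src[ci + x * C + y * C * W] : 0.f;
                    }
                }
            }
        }
        std::printf("im2col matrix: %d rows of %d elements\n", out_w * out_h, row_len);
        return 0;
    }
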
diff --git a/tests/validation/CL/Im2Col.cpp b/tests/validation/CL/Im2Col.cpp
index 9422fcc..291befa 100644
--- a/tests/validation/CL/Im2Col.cpp
+++ b/tests/validation/CL/Im2Col.cpp
@@ -41,8 +41,18 @@
 {
 namespace
 {
-const auto conv_filter_sizes = framework::dataset::make("KernelDims", { Size2D(3U, 3U), Size2D(3U, 1U), Size2D(1U, 5U), Size2D(5U, 5U), Size2D(7U, 7U) });
-const auto padstrides        = framework::dataset::make("PadStride", { PadStrideInfo(1U, 1U, 0U, 0U), PadStrideInfo(1U, 1U, 1U, 1U), PadStrideInfo(2U, 2U, 0U, 2U) });
+// *INDENT-OFF*
+// clang-format off
+const auto conv_filter_sizes = framework::dataset::make("KernelDims", { Size2D(3U, 3U),
+                                                                        Size2D(5U, 5U),
+                                                                        Size2D(3U, 1U),
+                                                                        Size2D(1U, 3U),
+                                                                        Size2D(5U, 3U),
+                                                                        Size2D(1U, 1U),
+                                                                        Size2D(11U, 11U)} );
+const auto padstrides        = framework::dataset::make("PadStride", { PadStrideInfo(1U, 1U, 0U, 0U),
+                                                                       PadStrideInfo(1U, 1U, 1U, 1U),
+                                                                       PadStrideInfo(2U, 2U, 0U, 2U) });
 const auto conv_args         = combine(combine(combine(conv_filter_sizes, padstrides),
                                                framework::dataset::make("QuantizationInfo", QuantizationInfo(0.5f, 10))),
                                        framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC }));
@@ -53,23 +63,19 @@
 
 using CLIm2Col = CLSynthetizeFunction<CLIm2ColKernel>;
 
-// *INDENT-OFF*
-// clang-format off
 DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
                framework::dataset::make("InputInfo", { TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::U8),      // Unsupported data type
                                                        TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::F32),     // Mismatching data type
                                                        TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), // Bias not supported with QASYMM8
-                                                       TensorInfo(TensorShape(10U, 12U, 2U), 1, DataType::QASYMM8), // Mismatching shapes
                                                        TensorInfo(TensorShape(10U, 12U, 2U, 2U), 1, DataType::QASYMM8),
                                                      }),
                framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16),
                                                        TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::F16),
                                                        TensorInfo(TensorShape(3U, 3U, 10U, 2U), 1, DataType::QASYMM8),
-                                                       TensorInfo(TensorShape(3U, 4U, 10U, 2U), 1, DataType::QASYMM8),
-                                                       TensorInfo(TensorShape(18U, 80U, 1U, 2U), 1, DataType::QASYMM8),
+                                                       TensorInfo(TensorShape(18U, 80U, 2U, 1U), 1, DataType::QASYMM8),
                                                      })),
-               framework::dataset::make("HasBias", { true, true, true, false, false })),
-               framework::dataset::make("Expected", { false, false, false, true, true })),
+               framework::dataset::make("HasBias", { true, true, true, false })),
+               framework::dataset::make("Expected", { false, false, false, true })),
                input_info, output_info, has_bias, expected)
 {
 
@@ -83,16 +89,18 @@
 using CLIm2ColFixture = Im2ColValidationFixture<CLTensor, CLAccessor, CLIm2Col, T, true>;
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
-                                                                                              conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
+                                                                                                      conv_args),
+                                                                                              framework::dataset::make("ChannelsFirstOutputNHWC", true)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
 TEST_SUITE_END()
 
-FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
-                                                                                                  conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
+                                                                                                          conv_args),
+                                                                                                  framework::dataset::make("ChannelsFirstOutputNHWC", true)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -101,14 +109,16 @@
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                             conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
+                                                                                                     conv_args),
+                                                                                             framework::dataset::make("ChannelsFirstOutputNHWC", true)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                                 conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)),
+                                                                                                         conv_args),
+                                                                                                 framework::dataset::make("ChannelsFirstOutputNHWC", true)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
@@ -120,14 +130,16 @@
 TEST_SUITE_END()
 
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, CLIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                        conv_args),
+                                                                                                framework::dataset::make("ChannelsFirstOutputNHWC", true)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                    conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, CLIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                            conv_args),
+                                                                                                    framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/CL/LocallyConnected.cpp b/tests/validation/CL/LocallyConnected.cpp
index 5381072..dbfe4e2 100644
--- a/tests/validation/CL/LocallyConnected.cpp
+++ b/tests/validation/CL/LocallyConnected.cpp
@@ -59,6 +59,7 @@
                                              TensorInfo(TensorShape(23U, 27U, 5U), 1, DataType::F32), // Mismatching shape input/bias
                                              TensorInfo(TensorShape(23U, 27U, 5U), 1, DataType::F32), // Mismatching shape input/output
                                              TensorInfo(TensorShape(23U, 27U, 5U), 1, DataType::F32), // Asymmetric padding
+                                             TensorInfo(TensorShape(23U, 27U, 5U), 1, DataType::F32), // Padding required
                                              TensorInfo(TensorShape(23U, 27U, 5U), 1, DataType::F32)
                                            }),
     framework::dataset::make("WeightsInfo",{ TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F16),
@@ -68,7 +69,8 @@
                                              TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F32),
                                              TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F32),
                                              TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F32)
+                                             TensorInfo(TensorShape(3U, 3U, 5U, 21U, 275U), 1, DataType::F32),
+                                             TensorInfo(TensorShape(1U, 3U, 5U, 21U, 575U), 1, DataType::F32)
                                            })),
     framework::dataset::make("BiasInfo",   { TensorInfo(TensorShape(21U, 275U), 1, DataType::F32),
                                              TensorInfo(TensorShape(21U, 275U), 1, DataType::F16),
@@ -77,7 +79,8 @@
                                              TensorInfo(TensorShape(21U, 274U), 1, DataType::F32),
                                              TensorInfo(TensorShape(21U, 275U), 1, DataType::F32),
                                              TensorInfo(TensorShape(21U, 275U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(21U, 275U), 1, DataType::F32)
+                                             TensorInfo(TensorShape(21U, 275U), 1, DataType::F32),
+                                             TensorInfo(TensorShape(21U, 575U), 1, DataType::F32)
                                            })),
     framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32),
                                              TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32),
@@ -86,7 +89,8 @@
                                              TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32),
                                              TensorInfo(TensorShape(11U, 25U, 22U), 1, DataType::F32),
                                              TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32),
-                                             TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32)
+                                             TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32),
+                                             TensorInfo(TensorShape(23U, 25U, 21U), 1, DataType::F32)
                                            })),
     framework::dataset::make("PadStride",  { PadStrideInfo(2, 1, 0, 0),
                                              PadStrideInfo(2, 1, 0, 0),
@@ -94,10 +98,11 @@
                                              PadStrideInfo(2, 1, 0, 0),
                                              PadStrideInfo(2, 1, 0, 0),
                                              PadStrideInfo(2, 1, 0, 0),
-                                             PadStrideInfo(2, 1, 1, 0, 0, 0, DimensionRoundingType::FLOOR),
-                                             PadStrideInfo(2, 1, 0, 0)
+                                             PadStrideInfo(2, 1, 1, 0),
+                                             PadStrideInfo(2, 1, 0, 0),
+                                             PadStrideInfo(1, 1, 0, 0)
                                            })),
-    framework::dataset::make("Expected", { false, false, false, false, false, false, false, true })),
+    framework::dataset::make("Expected", { false, false, false, false, false, false, false, false, true })),
     input_info, weights_info, bias_info, output_info, conv_info, expected)
 {
     bool is_valid = bool(CLLocallyConnectedLayer::validate(&input_info.clone()->set_is_resizable(false),
diff --git a/tests/validation/NEON/Im2Col.cpp b/tests/validation/NEON/Im2Col.cpp
index bff8634..f011ebe 100644
--- a/tests/validation/NEON/Im2Col.cpp
+++ b/tests/validation/NEON/Im2Col.cpp
@@ -77,14 +77,16 @@
 
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
-                                                                                              conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
+                                                                                                      conv_args),
+                                                                                              framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
-                                                                                                  conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
+                                                                                                          conv_args),
+                                                                                                  framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -94,14 +96,16 @@
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                             conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
+                                                                                                     conv_args),
+                                                                                             framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)),
-                                                                                                 conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<half>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F16)),
+                                                                                                         conv_args),
+                                                                                                 framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
@@ -113,14 +117,16 @@
 TEST_SUITE_END()
 
 TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                conv_args))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEIm2ColFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                        conv_args),
+                                                                                                framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
 }
-FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
-                                                                                                    conv_args))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEIm2ColFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
+                                                                                                            conv_args),
+                                                                                                    framework::dataset::make("ChannelsFirstOutputNHWC", false)))
 {
     // Validate output
     validate(Accessor(_target), _reference);
diff --git a/tests/validation/fixtures/FlattenLayerFixture.h b/tests/validation/fixtures/FlattenLayerFixture.h
index f273e93..d170806 100644
--- a/tests/validation/fixtures/FlattenLayerFixture.h
+++ b/tests/validation/fixtures/FlattenLayerFixture.h
@@ -55,7 +55,7 @@
     {
         TensorShape shape_flatten;
         TensorInfo  input_info(shape, 1, data_type);
-        shape_flatten = compute_im2col_flatten_shape(&input_info);
+        shape_flatten = compute_flatten_shape(&input_info);
 
         _target    = compute_target(shape, shape_flatten, data_type);
         _reference = compute_reference(shape, shape_flatten, data_type);
diff --git a/tests/validation/fixtures/Im2ColFixture.h b/tests/validation/fixtures/Im2ColFixture.h
index f72e38f..da2576b 100644
--- a/tests/validation/fixtures/Im2ColFixture.h
+++ b/tests/validation/fixtures/Im2ColFixture.h
@@ -49,7 +49,8 @@
 {
 public:
     template <typename...>
-    void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout)
+    void setup(TensorShape input_shape, DataType data_type, const Size2D &kernel_dims, const PadStrideInfo &conv_info, const QuantizationInfo &quant_info, const DataLayout &data_layout,
+               bool channels_first_output_nhwc)
     {
         _kernel_dims = kernel_dims;
         _conv_info   = conv_info;
@@ -68,7 +69,7 @@
         const TensorShape output_shape = compute_im2col_conv_shape(&input_info, _kernel_dims, _conv_info, _has_bias, Size2D(1U, 1U), batch_size_on_z);
         _target                        = compute_target(input_shape, output_shape, data_type);
 
-        compute_reference(input_shape, output_shape, data_type);
+        compute_reference(input_shape, output_shape, data_type, channels_first_output_nhwc);
     }
 
 protected:
@@ -107,14 +108,16 @@
         return dst;
     }
 
-    void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type)
+    void compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, DataType data_type, bool channels_first_output_nhwc)
     {
         // Create reference
         SimpleTensor<T> src{ input_shape, data_type, 1, _quant_info, _data_layout };
         _reference = SimpleTensor<T>(output_shape, data_type, 1, _quant_info, DataLayout::NCHW);
+
         // Fill reference
         fill(src);
-        reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias);
+
+        reference::im2col<T>(src, _reference, _kernel_dims, _conv_info, _has_bias, channels_first_output_nhwc);
     }
     TensorType       _target{};
     SimpleTensor<T>  _reference{};
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp
index 83ef8b4..2459499 100644
--- a/tests/validation/reference/Im2Col.cpp
+++ b/tests/validation/reference/Im2Col.cpp
@@ -23,8 +23,6 @@
  */
 #include "Im2Col.h"
 
-#include "Permute.h"
-
 #include "arm_compute/core/Types.h"
 #include "tests/validation/Helpers.h"
 #include "tests/validation/reference/Utils.h"
@@ -41,46 +39,45 @@
 void im2col_nchw(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
 {
     ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NCHW);
-    // Create reference
-    const int pad_x         = conv_info.pad().first;
-    const int pad_y         = conv_info.pad().second;
     const int stride_x      = conv_info.stride().first;
     const int stride_y      = conv_info.stride().second;
     const int kernel_width  = kernel_dims.width;
     const int kernel_height = kernel_dims.height;
+    const int pad_x         = conv_info.pad().first;
+    const int pad_y         = conv_info.pad().second;
     const int src_width     = src.shape().x();
     const int src_height    = src.shape().y();
-    const int src_depth     = src.shape().z();
+    const int src_channels  = src.shape().z();
     const int batches       = src.shape().total_size_upper(3);
+    const int dst_height    = dst.shape().y();
     const int pad_val       = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0;
+    int       dst_idx       = 0;
 
-    int dst_idx = 0;
-    // dst[dst_idx++] will write out of bounds if kernel_height == kernel_width == 1 because lasty will be the bottom padding row
-    // and this is not present in the dst buffer
-    const int lasty = src_height + (kernel_height > 1 ? pad_y : 0) - kernel_height;
-    const int lastx = src_width + (kernel_width > 1 ? pad_x : 0) - kernel_width;
+    // Compute width and height of the convolved tensors
+    std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(src_width, src_height, kernel_dims.width, kernel_dims.height, conv_info);
 
     for(int b = 0; b < batches; ++b)
     {
-        for(int y = -pad_y; y <= lasty; y += stride_y)
+        for(int yo = 0; yo < dst_height; ++yo)
         {
-            for(int x = -pad_x; x <= lastx; x += stride_x)
+            // Compute input spatial coordinates
+            const int xi = (yo % convolved_dims.first) * stride_x;
+            const int yi = (yo / convolved_dims.first) * stride_y;
+
+            for(int ci = 0; ci < src_channels; ++ci)
             {
-                for(int z = 0; z < src_depth; ++z)
+                for(int yk = 0; yk < kernel_height; ++yk)
                 {
-                    for(int patch_y = y; patch_y < (y + kernel_height); ++patch_y)
+                    for(int xk = 0; xk < kernel_width; ++xk)
                     {
-                        for(int patch_x = x; patch_x < (x + kernel_width); ++patch_x)
-                        {
-                            dst[dst_idx++] = tensor_elem_at(src, Coordinates(patch_x, patch_y, z, b), BorderMode::CONSTANT, static_cast<T>(pad_val));
-                        }
+                        dst[dst_idx++] = tensor_elem_at(src, Coordinates(xi + xk - pad_x, yi + yk - pad_y, ci, b), BorderMode::CONSTANT, static_cast<T>(pad_val));
                     }
                 }
+            }
 
-                if(has_bias)
-                {
-                    dst[dst_idx++] = static_cast<T>(1);
-                }
+            if(has_bias)
+            {
+                dst[dst_idx++] = static_cast<T>(1);
             }
         }
     }
@@ -133,7 +130,56 @@
 }
 
 template <typename T>
-void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+void im2col_nhwc_channel_first(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+{
+    ARM_COMPUTE_ERROR_ON(src.data_layout() != DataLayout::NHWC);
+    const int stride_x      = conv_info.stride().first;
+    const int stride_y      = conv_info.stride().second;
+    const int kernel_width  = kernel_dims.width;
+    const int kernel_height = kernel_dims.height;
+    const int pad_x         = conv_info.pad().first;
+    const int pad_y         = conv_info.pad().second;
+    const int src_width     = src.shape().y();
+    const int src_height    = src.shape().z();
+    const int src_channels  = src.shape().x();
+    const int batches       = src.shape().total_size_upper(3);
+    const int dst_width     = has_bias ? dst.shape().x() - 1 : dst.shape().x();
+    const int dst_height    = dst.shape().y();
+    const int pad_val       = is_data_type_quantized_asymmetric(src.data_type()) ? src.quantization_info().offset : 0;
+
+    // Compute width and height of the convolved tensors
+    std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(src_width, src_height, kernel_dims.width, kernel_dims.height, conv_info);
+
+    for(int b = 0; b < batches; ++b)
+    {
+        for(int yo = 0; yo < dst_height; ++yo)
+        {
+            // Compute input spatial coordinates
+            const int xi = (yo % convolved_dims.first) * stride_x;
+            const int yi = (yo / convolved_dims.first) * stride_y;
+
+            for(int ci = 0; ci < src_channels; ++ci)
+            {
+                for(int yk = 0; yk < kernel_height; ++yk)
+                {
+                    for(int xk = 0; xk < kernel_width; ++xk)
+                    {
+                        dst[ci + (xk + yk * kernel_width) * src_channels + yo * dst.shape().x() + b * dst.shape().x() * dst.shape().y()] = tensor_elem_at(src, Coordinates(ci, xi + xk - pad_x, yi + yk - pad_y, b),
+                                                                                                                                           BorderMode::CONSTANT, static_cast<T>(pad_val));
+                    }
+                }
+            }
+
+            if(has_bias)
+            {
+                dst[dst_width + yo * dst.shape().x() + b * dst.shape().x() * dst.shape().y()] = static_cast<T>(1);
+            }
+        }
+    }
+}
+
+template <typename T>
+void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool channels_first_output_nhwc)
 {
     switch(src.data_layout())
     {
@@ -144,7 +190,14 @@
         }
         case DataLayout::NHWC:
         {
-            im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias);
+            if(channels_first_output_nhwc)
+            {
+                im2col_nhwc_channel_first(src, dst, kernel_dims, conv_info, has_bias);
+            }
+            else
+            {
+                im2col_nhwc(src, dst, kernel_dims, conv_info, has_bias);
+            }
             break;
         }
         default:
@@ -155,9 +208,9 @@
     }
 }
 
-template void im2col(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias);
-template void im2col(const SimpleTensor<half> &src, SimpleTensor<half> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias);
-template void im2col(const SimpleTensor<float> &src, SimpleTensor<float> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias);
+template void im2col(const SimpleTensor<uint8_t> &src, SimpleTensor<uint8_t> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool channels_first_output_nhwc);
+template void im2col(const SimpleTensor<half> &src, SimpleTensor<half> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool channels_first_output_nhwc);
+template void im2col(const SimpleTensor<float> &src, SimpleTensor<float> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool channels_first_output_nhwc);
 } // namespace reference
 } // namespace validation
 } // namespace test
diff --git a/tests/validation/reference/Im2Col.h b/tests/validation/reference/Im2Col.h
index 5277171..b1ebaf2 100644
--- a/tests/validation/reference/Im2Col.h
+++ b/tests/validation/reference/Im2Col.h
@@ -35,7 +35,7 @@
 namespace reference
 {
 template <typename T>
-void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias);
+void im2col(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, bool channels_first_output_nhwc = false);
 } // namespace reference
 } // namespace validation
 } // namespace test
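
For reference, a hypothetical call site for the updated reference entry point could look as follows (the header paths and namespaces mirror the ones used by the validation tests and are assumptions, not part of the patch). The defaulted channels_first_output_nhwc parameter keeps existing callers unchanged, while the CL fixtures opt in to the new layout explicitly:

    #include "arm_compute/core/Types.h"
    #include "tests/SimpleTensor.h"                 // assumed header location for SimpleTensor
    #include "tests/validation/reference/Im2Col.h"

    using namespace arm_compute;
    using namespace arm_compute::test;
    using namespace arm_compute::test::validation;

    void im2col_reference_example()
    {
        // 2-channel 4x4 NHWC input, 3x3 kernel, stride 1, pad 1 -> 16 output rows of 2 * 3 * 3 = 18 elements
        SimpleTensor<float> src{ TensorShape(2U, 4U, 4U), DataType::F32, 1, QuantizationInfo(), DataLayout::NHWC };
        SimpleTensor<float> dst{ TensorShape(18U, 16U), DataType::F32, 1, QuantizationInfo(), DataLayout::NCHW };

        // Omitting the last argument keeps the pre-existing NHWC reference behaviour...
        reference::im2col<float>(src, dst, Size2D(3U, 3U), PadStrideInfo(1U, 1U, 1U, 1U), false /* has_bias */);
        // ...while passing true exercises the new channel-first NHWC path added by this patch.
        reference::im2col<float>(src, dst, Size2D(3U, 3U), PadStrideInfo(1U, 1U, 1U, 1U), false /* has_bias */, true);
    }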