Fix im2col for fast-maths mode with padding.

Following the investigation proposed by ONCPUML-1193, padding
is implemented in im2col when the input channel is not a multiple of
blocks requested by the weight format.

Partially resolves: ONCPUML-1193

Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Change-Id: I350c7a1b2dcae63f8d94f5b6f1f86e948eab1f09
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9508
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9b1ebf6..f935265 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -432,8 +432,8 @@
     const int        weights_width_idx   = get_data_layout_dimension_index(weights_data_layout, DataLayoutDimension::WIDTH);
     const int        weights_height_idx  = get_data_layout_dimension_index(weights_data_layout, DataLayoutDimension::HEIGHT);
 
-    unsigned int output_width  = 0;
-    unsigned int output_height = 0;
+    unsigned int output_width             = 0;
+    unsigned int output_height            = 0;
     std::tie(output_width, output_height) = scaled_dimensions(input_shape[width_idx], input_shape[height_idx],
                                                               weights_shape[weights_width_idx], weights_shape[weights_height_idx],
                                                               info.pad_stride_info, info.dilation);
@@ -517,11 +517,12 @@
  * @param[in] dilation        Dilation, in elements, across x and y
  * @param[in] batch_size_on_z True if batch size is on z axis
  * @param[in] num_groups      (Optional)  Number of groups when performing a grouped convolution
+ * @param[in] input_pad_right (Optional) When fast-math is selected, per element padding for the im2col matrix may be necessary
  *
  * @return the calculated shape
  */
 inline TensorShape compute_im2col_conv_shape(const ITensorInfo *input, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias, const Size2D &dilation, bool batch_size_on_z,
-                                             unsigned int num_groups = 1)
+                                             unsigned int num_groups = 1, unsigned int input_pad_right = 0)
 {
     // The output shape will be the 3D shape [ out_channels * kernel_area, num_elems_per_out_channel, batches ]                           if batch_size_on_z == true
     //                       or the 4D shape [ out_channels * kernel_area / num_groups, num_elems_per_out_channel, num_groups, batches ]  if batch_size_on_z == false
@@ -538,7 +539,7 @@
     const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
     std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(output_shape[width_idx], output_shape[height_idx], kernel_dims.width, kernel_dims.height, conv_info, dilation);
-    output_shape.set(0, (output_shape[channel_idx] / num_groups * kernel_dims.area() + (has_bias ? 1 : 0))); // NOLINT
+    output_shape.set(0, ((output_shape[channel_idx] + input_pad_right) / num_groups * kernel_dims.area() + (has_bias ? 1 : 0))); // NOLINT
     output_shape.set(1, (out_dims.first * out_dims.second));
     if(batch_size_on_z && output_shape.num_dimensions() >= 3)
     {
@@ -682,8 +683,8 @@
     const DataLayout    data_layout      = winograd_info.output_data_layout;
 
     // Compute output shape
-    unsigned int output_width  = 0;
-    unsigned int output_height = 0;
+    unsigned int output_width             = 0;
+    unsigned int output_height            = 0;
     std::tie(output_width, output_height) = scaled_dimensions(input_dimensions.width, input_dimensions.height,
                                                               kernel_size.width, kernel_size.height, conv_info);
 
@@ -723,7 +724,7 @@
     const unsigned int weights_out_channel = weights_shape[3];
     unsigned int       output_width        = 0;
     unsigned int       output_height       = 0;
-    std::tie(output_width, output_height) = scaled_dimensions(input_width, input_height, weights_width, weights_height, conv_info);
+    std::tie(output_width, output_height)  = scaled_dimensions(input_width, input_height, weights_width, weights_height, conv_info);
 
     TensorShape output_shape{ input_shape };
     output_shape.set(idx_width, output_width);
diff --git a/src/cpu/kernels/CpuIm2ColKernel.cpp b/src/cpu/kernels/CpuIm2ColKernel.cpp
index 25ff6c2..9ac2915 100644
--- a/src/cpu/kernels/CpuIm2ColKernel.cpp
+++ b/src/cpu/kernels/CpuIm2ColKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -52,7 +52,7 @@
 namespace
 {
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                          bool has_bias, const Size2D &dilation, unsigned int num_groups)
+                          bool has_bias, const Size2D &dilation, unsigned int num_groups, unsigned int input_pad_right)
 {
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output);
@@ -70,7 +70,7 @@
 
     if(output->total_size() > 0)
     {
-        TensorInfo expected_output = output->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false));
+        TensorInfo expected_output = output->clone()->set_tensor_shape(compute_im2col_conv_shape(input, kernel_dims, conv_info, has_bias, dilation, false, num_groups, input_pad_right));
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
@@ -246,6 +246,86 @@
         *out_ptr = static_cast<T>(1);
     }
 }
+
+template <typename T, bool has_pads>
+inline void linearize_volume_nhwc(const uint8_t *const in_ptr,
+                                  T                   *out_ptr,
+                                  bool                 has_bias,
+                                  int                  start_x,
+                                  int                  start_y,
+                                  int                  kernel_width,
+                                  int                  kernel_height,
+                                  int                  input_w,
+                                  int                  input_h,
+                                  int                  input_c,
+                                  int                  input_stride_y,
+                                  int                  input_stride_z,
+                                  int                  pad_value,
+                                  int                  dilation_x,
+                                  int                  dilation_y,
+                                  int                  pad_right)
+{
+    const int end_x              = start_x + kernel_width * dilation_x;
+    const int end_y              = start_y + kernel_height * dilation_y;
+    const int pad_quant          = kernel_width * (input_c + pad_right);
+    const int element_size       = static_cast<int>(sizeof(T));
+    const int channel_chunk_size = input_c * element_size;
+
+    if((start_y >= 0) && (end_y < input_h) && (start_x >= 0) && (end_x < input_w) && (dilation_x == 1) && (input_stride_y == channel_chunk_size))
+    {
+        for(int y = start_y; y < end_y; y += dilation_y)
+        {
+            const uint8_t *offset_ptr = in_ptr + (y * input_stride_z + start_x * input_stride_y);
+            for(int e = 0; e < kernel_width; e++)
+            {
+                memcpy(out_ptr, reinterpret_cast<const T *>(offset_ptr + e * channel_chunk_size), channel_chunk_size);
+                out_ptr += input_c + pad_right;
+            }
+        }
+    }
+    else
+    {
+        for(int y = start_y; y < end_y; y += dilation_y)
+        {
+            if(y < 0 || y >= input_h)
+            {
+                memset(static_cast<void *>(out_ptr), pad_value, pad_quant * element_size);
+                out_ptr += pad_quant;
+            }
+            else if(dilation_x > 1 || start_x < 0 || end_x >= input_w || input_stride_y != channel_chunk_size)
+            {
+                for(int x = start_x; x < end_x; x += dilation_x)
+                {
+                    if(x < 0 || x >= input_w)
+                    {
+                        memset(static_cast<void *>(out_ptr), pad_value, (input_c + pad_right) * element_size);
+                        out_ptr += input_c + pad_right;
+                    }
+                    else
+                    {
+                        memcpy(out_ptr, reinterpret_cast<const T *>(in_ptr + (y * input_stride_z + x * input_stride_y)), channel_chunk_size);
+                        out_ptr += input_c + pad_right;
+                    }
+                }
+            }
+            else
+            {
+                const uint8_t *offset_ptr = in_ptr + (y * input_stride_z + start_x * input_stride_y);
+                for(int e = 0; e < kernel_width; e++)
+                {
+                    memcpy(out_ptr, reinterpret_cast<const T *>(offset_ptr + e * channel_chunk_size), channel_chunk_size);
+                    out_ptr += input_c + pad_right;
+                }
+            }
+        }
+    }
+    // Append 1 if the convolution layer has biases
+    if(has_bias)
+    {
+        *out_ptr = static_cast<T>(1);
+    }
+}
+
 } // namespace
 
 template <typename T, bool has_pads, bool is_nchw>
@@ -280,7 +360,8 @@
     Iterator in(src, window_in_out);
     Iterator out(dst, window_in_out);
 
-    execute_window_loop(window, [&](const Coordinates & id)
+    execute_window_loop(
+        window, [&](const Coordinates & id)
     {
         const int start_w = id[width_idx] * stride_x - pad_left;
         const int start_h = id[height_idx] * stride_y - pad_top;
@@ -311,31 +392,53 @@
         }
         else
         {
-            linearize_volume_nhwc<T, has_pads>(input_ptr,
-                                               output_ptr,
-                                               _has_bias,
-                                               start_w,
-                                               start_h,
-                                               _kernel_width,
-                                               _kernel_height,
-                                               input_w,
-                                               input_h,
-                                               input_c,
-                                               input_stride_y,
-                                               input_stride_z,
-                                               pad_value,
-                                               _dilation.x(),
-                                               _dilation.y());
+            if(_input_pad_right > 0)
+            {
+                linearize_volume_nhwc<T, has_pads>(input_ptr,
+                                                   output_ptr,
+                                                   _has_bias,
+                                                   start_w,
+                                                   start_h,
+                                                   _kernel_width,
+                                                   _kernel_height,
+                                                   input_w,
+                                                   input_h,
+                                                   input_c,
+                                                   input_stride_y,
+                                                   input_stride_z,
+                                                   pad_value,
+                                                   _dilation.x(),
+                                                   _dilation.y(),
+                                                   _input_pad_right);
+            }
+            else
+            {
+                linearize_volume_nhwc<T, has_pads>(input_ptr,
+                                                   output_ptr,
+                                                   _has_bias,
+                                                   start_w,
+                                                   start_h,
+                                                   _kernel_width,
+                                                   _kernel_height,
+                                                   input_w,
+                                                   input_h,
+                                                   input_c,
+                                                   input_stride_y,
+                                                   input_stride_z,
+                                                   pad_value,
+                                                   _dilation.x(),
+                                                   _dilation.y());
+            }
         }
     },
     in, out);
 }
 
 void CpuIm2ColKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                                bool has_bias, const Size2D &dilation, unsigned int num_groups)
+                                bool has_bias, const Size2D &dilation, unsigned int num_groups, unsigned int input_pad_right)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, kernel_dims, conv_info, has_bias, dilation, num_groups));
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, kernel_dims, conv_info, has_bias, dilation, num_groups, input_pad_right));
     ARM_COMPUTE_UNUSED(num_groups);
 
     _data_layout                   = src->data_layout();
@@ -343,14 +446,15 @@
     const unsigned int height_idx  = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
     const unsigned int channel_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL);
 
-    _conv_info      = conv_info;
-    _kernel_width   = kernel_dims.width;
-    _kernel_height  = kernel_dims.height;
-    _dilation       = dilation;
-    _convolved_dims = scaled_dimensions(src->dimension(width_idx), dst->dimension(height_idx),
-                                        _kernel_width, _kernel_height,
-                                        _conv_info, _dilation);
-    _has_bias = has_bias;
+    _conv_info       = conv_info;
+    _kernel_width    = kernel_dims.width;
+    _kernel_height   = kernel_dims.height;
+    _input_pad_right = input_pad_right;
+    _dilation        = dilation;
+    _convolved_dims  = scaled_dimensions(src->dimension(width_idx), dst->dimension(height_idx),
+                                         _kernel_width, _kernel_height,
+                                         _conv_info, _dilation);
+    _has_bias        = has_bias;
 
     if(_data_layout == DataLayout::NCHW)
     {
@@ -408,7 +512,7 @@
     }
 
     // Output tensor auto initialization if not yet initialized
-    auto_init_if_empty(*dst, src->clone()->set_tensor_shape(compute_im2col_conv_shape(src, kernel_dims, conv_info, has_bias, dilation, false)));
+    auto_init_if_empty(*dst, src->clone()->set_tensor_shape(compute_im2col_conv_shape(src, kernel_dims, conv_info, has_bias, dilation, false, num_groups, _input_pad_right)));
 
     std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(src->dimension(width_idx), src->dimension(height_idx),
                                                                              kernel_dims.width, kernel_dims.height,
@@ -423,9 +527,9 @@
 }
 
 Status CpuIm2ColKernel::validate(const ITensorInfo *src, const ITensorInfo *dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                                 bool has_bias, const Size2D &dilation, unsigned int num_groups)
+                                 bool has_bias, const Size2D &dilation, unsigned int num_groups, unsigned int input_pad_right)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, kernel_dims, conv_info, has_bias, dilation, num_groups));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, dst, kernel_dims, conv_info, has_bias, dilation, num_groups, input_pad_right));
     return Status{};
 }
 
@@ -437,8 +541,10 @@
 
     auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
     auto dst = tensors.get_tensor(TensorType::ACL_DST);
+
     (this->*_func)(src, dst, window);
 }
+
 const char *CpuIm2ColKernel::name() const
 {
     return "CpuIm2ColKernel";
diff --git a/src/cpu/kernels/CpuIm2ColKernel.h b/src/cpu/kernels/CpuIm2ColKernel.h
index 8160310..d133f8d 100644
--- a/src/cpu/kernels/CpuIm2ColKernel.h
+++ b/src/cpu/kernels/CpuIm2ColKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -66,19 +66,20 @@
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuIm2ColKernel);
     /** Set the input and output of the kernel.
      *
-     * @param[in]  src         The input tensor info to convert. 3 lower dimensions represent a single input [width, height, IFM],
-     *                         while every optional dimension from 4 and above represent a batch of inputs.
-     *                         Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32
-     *                         Note: QASYMM8/QASYMM8_SIGNED works only for has_bias = false
-     * @param[out] dst         The output tensor info. Data types supported: Same as @p input
-     * @param[in]  kernel_dims The kernel dimensions (width and height).
-     * @param[in]  conv_info   Contains padding and stride information described in @ref PadStrideInfo.
-     * @param[in]  has_bias    In case biases are provided expands the matrix with 1.
-     * @param[in]  dilation    (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
-     * @param[in]  num_groups  (Optional) Number of groups when performing a grouped convolution. num_groups != 1 is not supported
+     * @param[in]  src             The input tensor info to convert. 3 lower dimensions represent a single input [width, height, IFM],
+     *                             while every optional dimension from 4 and above represent a batch of inputs.
+     *                             Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32
+     *                             Note: QASYMM8/QASYMM8_SIGNED works only for has_bias = false
+     * @param[out] dst             The output tensor info. Data types supported: Same as @p input
+     * @param[in]  kernel_dims     The kernel dimensions (width and height).
+     * @param[in]  conv_info       Contains padding and stride information described in @ref PadStrideInfo.
+     * @param[in]  has_bias        In case biases are provided expands the matrix with 1.
+     * @param[in]  dilation        (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
+     * @param[in]  num_groups      (Optional) Number of groups when performing a grouped convolution. num_groups != 1 is not supported
+     * @param[in]  input_pad_right (Optional) When fast-math is selected, per element padding for the im2col matrix may be necessary
      */
     void configure(const ITensorInfo *src, ITensorInfo *dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                   bool has_bias, const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1);
+                   bool has_bias, const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1, unsigned int input_pad_right = 0);
     /** Static function to check if given info will lead to a valid configuration
      *
      * Similar to CpuIm2ColKernel::configure()
@@ -86,10 +87,10 @@
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
-                           bool has_bias, const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1);
+                           bool has_bias, const Size2D &dilation = Size2D(1U, 1U), unsigned int num_groups = 1, unsigned int input_pad_right = 0);
 
     // Inherited methods overridden:
-    void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+    void        run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
     const char *name() const override;
     /** Return minimum workload size of the relevant kernel
      *
@@ -116,14 +117,15 @@
      */
     using Im2ColFunctionPtr = void (CpuIm2ColKernel::*)(const ITensor *src, ITensor *dst, const Window &window);
 
-    Im2ColFunctionPtr _func{ nullptr };
+    Im2ColFunctionPtr                     _func{ nullptr };
     std::pair<unsigned int, unsigned int> _convolved_dims{};
-    PadStrideInfo _conv_info{};
-    unsigned int  _kernel_width{ 0 };
-    unsigned int  _kernel_height{ 0 };
-    bool          _has_bias{ false };
-    Size2D        _dilation{ 1U, 1U };
-    DataLayout    _data_layout{ DataLayout::UNKNOWN };
+    PadStrideInfo                         _conv_info{};
+    unsigned int                          _kernel_width{ 0 };
+    unsigned int                          _kernel_height{ 0 };
+    unsigned int                          _input_pad_right{ 0 };
+    bool                                  _has_bias{ false };
+    Size2D                                _dilation{ 1U, 1U };
+    DataLayout                            _data_layout{ DataLayout::UNKNOWN };
 };
 } // namespace kernels
 } // namespace cpu
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index b9d18c4..7411d76 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -155,14 +155,14 @@
 Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
-    const bool is_c_bias = beta == 1 && c != nullptr;
+    const bool is_c_bias    = beta == 1 && c != nullptr;
     const bool run_addition = c != nullptr && beta != 0 && beta != 1;
 
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);
 
-    if (is_fixed_format_fast_math(gemm_info.weight_format()))
+    if(is_fixed_format_fast_math(gemm_info.weight_format()))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
@@ -172,7 +172,24 @@
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
     }
 
-    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+    const int block_by = arm_compute::block_by(gemm_info.weight_format());
+    if(block_by > 1)
+    {
+        // have to verify bias
+        const size_t dim0_sz = a->dimension(0);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG((dim0_sz % block_by) != 0, ("The matrix A number of columns must be a multiple of block_by=" + std::to_string(block_by)).c_str());
+        // a->dimension(0) = kernel_area * input_channel + kernel_area * input_pad_right
+        // b->dimension(1) = kernel_area * input_channel
+        // a->dimension(0) = b->dimension(1) + kernel_area * input_pad_right
+        const size_t input_pad_right = (dim0_sz - b->dimension(1)) % block_by;
+        const size_t kernel_area     = (dim0_sz - b->dimension(1)) / input_pad_right;
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG((dim0_sz - kernel_area * input_pad_right) != b->dimension(1), "The product AB is defined only if A number of columns and B number of rows are related");
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
+    }
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
     if(a->data_type() != DataType::BFLOAT16)
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index ebf2ebc..7c0e58b 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -62,13 +62,13 @@
     const unsigned int kernel_height = weights->dimension(idx_height);
     unsigned int       conv_w        = 0;
     unsigned int       conv_h        = 0;
-    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
-                                                 src->dimension(idx_height),
-                                                 kernel_width,
-                                                 kernel_height,
-                                                 conv_info,
-                                                 dilation);
-    const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
+    std::tie(conv_w, conv_h)         = scaled_dimensions(src->dimension(idx_width),
+                                                         src->dimension(idx_height),
+                                                         kernel_width,
+                                                         kernel_height,
+                                                         conv_info,
+                                                         dilation);
+    const bool skip_im2col           = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
 
     if(skip_im2col)
     {
@@ -139,8 +139,8 @@
         PixelValue type_min{};
         PixelValue type_max{};
         std::tie(type_min, type_max) = get_min_max(data_type);
-        int32_t min_activation = type_min.get<int32_t>();
-        int32_t max_activation = type_max.get<int32_t>();
+        int32_t min_activation       = type_min.get<int32_t>();
+        int32_t max_activation       = type_max.get<int32_t>();
 
         if(supported_acts.count(act_info.activation()) != 0)
         {
@@ -203,8 +203,8 @@
         PixelValue type_min{};
         PixelValue type_max{};
         std::tie(type_min, type_max) = get_min_max(data_type);
-        int32_t min_activation = type_min.get<int32_t>();
-        int32_t max_activation = type_max.get<int32_t>();
+        int32_t min_activation       = type_min.get<int32_t>();
+        int32_t max_activation       = type_max.get<int32_t>();
 
         const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
                                                                                    ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
@@ -274,6 +274,7 @@
     const DataLayout data_layout = src->data_layout();
     const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
     const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
     const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
 
     const unsigned int kernel_width  = weights->dimension(idx_width);
@@ -288,8 +289,8 @@
     ITensorInfo       *gemm_output_to_use = dst;
 
     // Get convolved dimensions
-    unsigned int conv_w = 0;
-    unsigned int conv_h = 0;
+    unsigned int conv_w      = 0;
+    unsigned int conv_h      = 0;
     std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
                                                  src->dimension(idx_height),
                                                  kernel_width,
@@ -306,8 +307,8 @@
     _skip_col2im                            = skip_info.skip_col2im;
 
     // Get parameters from conv_info
-    unsigned int stride_x = 0;
-    unsigned int stride_y = 0;
+    unsigned int stride_x        = 0;
+    unsigned int stride_y        = 0;
     std::tie(stride_x, stride_y) = conv_info.stride();
 
     unsigned int mat_weights_cols = weights->dimension(idx_kernels);
@@ -321,9 +322,15 @@
     // Create tensor to store im2col reshaped inputs
     if(!_skip_im2col)
     {
+        const int    block_by        = arm_compute::block_by(weights_info.weight_format());
+        unsigned int input_pad_right = 0;
+        if(block_by > 1)
+        {
+            input_pad_right = (src->dimension(idx_channel) % block_by) == 0 ? 0 : block_by - (src->dimension(idx_channel) % block_by);
+        }
         // Configure
         _im2col_kernel = std::make_unique<kernels::CpuIm2ColKernel>();
-        _im2col_kernel->configure(src, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, false, dilation);
+        _im2col_kernel->configure(src, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, false, dilation, num_groups, input_pad_right);
 
         // Update GEMM input
         gemm_input_to_use = &_im2col_output;
@@ -399,12 +406,12 @@
     const unsigned int kernel_height = weights->dimension(idx_height);
     unsigned int       conv_w        = 0;
     unsigned int       conv_h        = 0;
-    std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width),
-                                                 src->dimension(idx_height),
-                                                 kernel_width,
-                                                 kernel_height,
-                                                 conv_info,
-                                                 dilation);
+    std::tie(conv_w, conv_h)         = scaled_dimensions(src->dimension(idx_width),
+                                                         src->dimension(idx_height),
+                                                         kernel_width,
+                                                         kernel_height,
+                                                         conv_info,
+                                                         dilation);
 
     const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
                                                                               dilation, act_info);
@@ -428,7 +435,7 @@
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
 
-    if (!is_fixed_format(weights_info.weight_format()))
+    if(!is_fixed_format(weights_info.weight_format()))
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
     }
@@ -469,9 +476,9 @@
                                                  dilation);
 
     // Check if GEMM3D is supported
-    const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
-                                                                              dilation, act_info);
-    const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
+    const CpuGemmConv2d::SkipInfo skip_info   = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info,
+                                                                                dilation, act_info);
+    const bool                    skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im;
 
     ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel));
     ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -504,6 +511,14 @@
 
     if(!skip_im2col)
     {
+        const int block_by        = arm_compute::block_by(weights_info.weight_format());
+        int       input_pad_right = 0;
+        if(block_by > 1)
+        {
+            input_pad_right  = (src->dimension(idx_channel) % block_by) == 0 ? 0 : block_by - (src->dimension(idx_channel) % block_by);
+            mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * (weights->dimension(idx_channel) + input_pad_right);
+        }
+
         // Create tensor info for im2col reshaped inputs
         // For CPU, the batch size is on the fourth dimension
         TensorShape shape_im2col = src->tensor_shape();
@@ -513,7 +528,7 @@
 
         im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
         im2col_reshaped_info.set_quantization_info(src->quantization_info());
-        ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1));
+        ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, num_groups, input_pad_right));
         gemm_input_to_use = &im2col_reshaped_info;
     }
 
@@ -563,7 +578,7 @@
     {
         // Run input reshaping
         unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
-        ITensorPack  pack =
+        ITensorPack  pack  =
         {
             { TensorType::ACL_SRC, src },
             { TensorType::ACL_DST, im2col_output.get() }
@@ -657,7 +672,7 @@
         // Run weights reshaping and mark original weights tensor as unused
         CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors);
         auto                weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
-        ITensorPack         pack =
+        ITensorPack         pack    =
         {
             { TensorType::ACL_SRC, weights },
             { TensorType::ACL_DST, weights_reshaped.get() }