Integrate improved CPU depthwise convolution kernels

* Replace assembly kernels for depthwise convolution with more optimized
  ones.
* Add int8 assembly kernels.
* Fix implicit padding in optimized kernels.

Resolves: COMPMID-3867, COMPMID-4361

Change-Id: I0b0867e05f61be4f368f62190d55e14d0ab3ebf2
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5622
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp b/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
index 160a9fd..f577e94 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2d.cpp
@@ -62,8 +62,8 @@
 
     ARM_COMPUTE_RETURN_ON_ERROR(CpuDepthwiseConv2dAssemblyDispatch::validate(src, weights, biases, dst, info));
 
-    //Validate Activation Layer
-    if(info.act_info.enabled())
+    // Validate Activation Layer
+    if(info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info))
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(dst, nullptr, info.act_info));
     }
@@ -95,15 +95,7 @@
     _is_prepared  = false;
 
     // Configure pipeline
-    ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
-    const bool          is_relu         = arm_compute::utils::info_helpers::is_relu(info.act_info);
-    const bool          is_relu6        = arm_compute::utils::info_helpers::is_relu6(info.act_info);
-    _is_activationlayer_enabled         = info.act_info.enabled() && !(is_relu || is_relu6);
-
-    if(!_is_activationlayer_enabled)
-    {
-        act_info_to_use = info.act_info;
-    }
+    _is_activationlayer_enabled = info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info);
 
     _dwc_optimized_func = std::make_unique<CpuDepthwiseConv2dAssemblyDispatch>();
     if(_is_nchw)
@@ -359,7 +351,7 @@
     }
 
     // Validate Activation Layer
-    if(info.act_info.enabled())
+    if(info.act_info.enabled() && !CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(info.act_info))
     {
         ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(dst, nullptr, info.act_info));
     }
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2d.h b/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
index 049397f..ae9f894 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2d.h
@@ -92,9 +92,8 @@
     *
     * -# @ref NEFillBorderKernel (if pad_x or pad_y > 0) and no assembly kernel implementation is present
     * -# @ref CpuDepthwiseConv2d3x3Kernel if 3x3 and no assembly kernel implementation is present
-    * -# @ref NEDepthwiseConvolutionAssemblyDispatch if assembly kernel implementation is present
-    * -# @ref NEDirectConvolutionLayerOutputStageKernel if re-quantization of dst is required
-    * -# @ref NEActivationLayer if fused activation is required
+    * -# @ref CpuDepthwiseConv2dAssemblyDispatch if assembly kernel implementation is present
+    * -# @ref CpuActivation if fused activation is required
     *
     */
     class CpuDepthwiseConv2dOptimizedInternal : public ICpuOperator
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
index a36ee1d..660ac01 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
@@ -24,315 +24,22 @@
 
 #include "src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h"
 
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/misc/InfoHelpers.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
-#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/assembly/NEDepthwiseConvolutionAssemblyKernelWrapper.h"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_dilated.hpp"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_quantized_dilated.hpp"
-#include "src/core/helpers/AutoConfiguration.h"
-
+#include "arm_compute/core/ITensorInfo.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
-
-#include <set>
+#include "src/core/CPP/Validate.h"
+#include "src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h"
+#include "src/core/helpers/AutoConfiguration.h"
+#include "src/core/utils/AssemblyUtils.h"
 
 namespace arm_compute
 {
 namespace cpu
 {
-namespace
-{
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_qasymm8_convolver(int kernel_size, int stride_x,
-                                                                        int n_batches, int in_rows, int in_cols, int n_channels,
-                                                                        int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                                                        const qasymm8::QAsymm8Params &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
-                                                                        const qasymm8::QAsymm8RescaleParams &rescale_params,
-                                                                        int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 5, 5, 1, 1>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 5, 5, 2, 2>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_qsymm8_perchannel_convolver(int kernel_size, int stride_x,
-                                                                                  int n_batches, int in_rows, int in_cols, int n_channels,
-                                                                                  neon_convolution_kernels::ActivationFunction activation,
-                                                                                  const qsymm8::QSymm8PerChannelParams &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
-                                                                                  const qsymm8::QSymm8PerChannelRescaleParams &rescale_params,
-                                                                                  int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 5, 5, 1, 1>>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 5, 5, 2, 2>>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp16_convolver(int kernel_size, int stride_x,
-                                                                     int n_batches, int in_rows, int in_cols, int n_channels,
-                                                                     int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                                                     int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 1, 1, float16_t, float16_t, float16_t>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float16_t, float16_t, float16_t>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 1, 1, float16_t, float16_t, float16_t>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float16_t, float16_t, float16_t>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp32_convolver(int kernel_size, int stride_x,
-                                                                     int n_batches, int in_rows, int in_cols, int n_channels,
-                                                                     int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                                                     int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float, float>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float, float>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 5, 5, 1, 1, float, float, float>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float, float, float>>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-std::unique_ptr<depthwise::IDepthwiseConvolution> create_convolver(const ITensorInfo     *src,
-                                                                   const ITensorInfo     *weights,
-                                                                   ITensorInfo           *output,
-                                                                   const ConvolutionInfo &info)
-{
-    const DataType    data_type = src->data_type();
-    const TensorShape shape     = src->tensor_shape();
-
-    const int n_batches       = shape[3];
-    const int in_rows         = shape.z();
-    const int in_cols         = shape.y();
-    const int n_channels      = shape.x();
-    const int dilation_factor = info.dilation.x();
-    const int padding_top     = info.pad_stride_info.pad_top();
-    const int padding_left    = info.pad_stride_info.pad_left();
-    const int padding_bottom  = info.pad_stride_info.pad_bottom();
-    const int padding_right   = info.pad_stride_info.pad_right();
-
-    const bool is_uniform_quantized    = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QASYMM8);
-    const bool is_perchannel_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QSYMM8_PER_CHANNEL);
-
-    const unsigned int stride_x    = info.pad_stride_info.stride().first;
-    const unsigned int kernel_size = weights->tensor_shape().y();
-
-    // Map activation function
-    neon_convolution_kernels::ActivationFunction activation = neon_convolution_kernels::ActivationFunction::None;
-    if(arm_compute::utils::info_helpers::is_relu(info.act_info))
-    {
-        activation = neon_convolution_kernels::ActivationFunction::ReLU;
-    }
-    else if(arm_compute::utils::info_helpers::is_relu6(info.act_info))
-    {
-        activation = neon_convolution_kernels::ActivationFunction::ReLU6;
-    }
-
-    // Create quantized convolver
-    if(is_uniform_quantized)
-    {
-        const UniformQuantizationInfo input_qinfo   = src->quantization_info().uniform();
-        const UniformQuantizationInfo weights_qinfo = weights->quantization_info().uniform();
-        const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
-
-        // Check that quantization info are in the range [0, 255]
-        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(weights_qinfo.offset < 0 || weights_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
-        const qasymm8::QAsymm8Params iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
-        const qasymm8::QAsymm8Params wqinfo{ static_cast<uint8_t>(weights_qinfo.offset), weights_qinfo.scale };
-        const qasymm8::QAsymm8Params oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
-
-        // Calculate rescale parameters
-        const float fmultipler  = iqinfo.scale * wqinfo.scale / oqinfo.scale;
-        int32_t     qmultiplier = 0;
-        int32_t     qshift      = 0;
-        quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
-        qasymm8::QAsymm8RescaleParams rescale_params(qshift, qmultiplier, fmultipler);
-
-        return get_qasymm8_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation,
-                                     wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-    }
-    else if(is_perchannel_quantized)
-    {
-        const UniformQuantizationInfo input_qinfo   = src->quantization_info().uniform();
-        const QuantizationInfo        weights_qinfo = weights->quantization_info();
-        const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
-
-        // Check that quantization info are in the range [0, 255]
-        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
-        const qasymm8::QAsymm8Params         iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
-        const qsymm8::QSymm8PerChannelParams wqinfo{ weights_qinfo.scale() };
-        const qasymm8::QAsymm8Params         oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
-
-        // Calculate rescale parameters
-        std::vector<float>   fmultipliers;
-        std::vector<int32_t> qmultipliers;
-        std::vector<int32_t> qshifts;
-
-        for(auto const s : wqinfo.scales)
-        {
-            const float fmultipler  = iqinfo.scale * s / oqinfo.scale;
-            int32_t     qmultiplier = 0;
-            int32_t     qshift      = 0;
-            quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
-            fmultipliers.push_back(fmultipler);
-            qmultipliers.push_back(qmultiplier);
-            qshifts.push_back(qshift);
-        }
-
-        qsymm8::QSymm8PerChannelRescaleParams rescale_params(qshifts, qmultipliers, fmultipliers);
-
-        return get_qsymm8_perchannel_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, activation,
-                                               wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-    }
-    else
-    {
-        // Create float convolver
-        switch(data_type)
-        {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-            {
-                return get_fp16_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-            }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F32:
-            {
-                return get_fp32_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-            }
-            default:
-                return nullptr;
-        }
-    }
-}
-} // namespace
-
 struct CpuDepthwiseConv2dAssemblyDispatch::LocalImpl
 {
-    std::unique_ptr<depthwise::IDepthwiseConvolution> dwc_assembly_kernel{ nullptr };
-    NEDepthwiseConvolutionAssemblyKernelWrapper       dwc_acl_kernel{};
-    bool                                              is_prepared{ false };
-    experimental::MemoryRequirements                  mem_req{};
+    std::unique_ptr<kernels::CpuDepthwiseConv2dAssemblyWrapperKernel> asm_kernel{ nullptr };
+    bool                                                              is_prepared{ false };
+    experimental::MemoryRequirements                                  mem_req{};
 };
 
 #ifndef DOXYGEN_SKIP_THIS
@@ -350,40 +57,30 @@
                                                    ITensorInfo           *dst,
                                                    const ConvolutionInfo &info)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
-    ARM_COMPUTE_UNUSED(bias);
-    ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConv2dAssemblyDispatch::validate(src,
-                                                                            weights,
-                                                                            bias != nullptr ? bias : nullptr,
-                                                                            dst,
-                                                                            info));
+    const CPUInfo     &ci          = NEScheduler::get().cpu_info();
+    const unsigned int num_threads = NEScheduler::get().num_threads();
+    _pImpl->is_prepared            = false;
 
-    // Output auto inizialitation if not yet initialized
-    const TensorShape dst_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*src, *weights, info);
-    auto_init_if_empty(*dst, src->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(dst_shape).set_quantization_info(dst->quantization_info()));
+    // If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
+    if(!CpuDepthwiseConv2dAssemblyDispatch::validate(src, weights, bias, dst, info))
+    {
+        return;
+    }
 
-    _pImpl->is_prepared = false;
+    auto dwc_wrapper = std::make_unique<kernels::CpuDepthwiseConv2dAssemblyWrapperKernel>();
+    ARM_COMPUTE_ERROR_ON(dwc_wrapper == nullptr);
+    dwc_wrapper->configure(src, weights, bias, dst, info, ci);
 
-    // Create convolver
-    _pImpl->dwc_assembly_kernel = create_convolver(src, weights, dst, info);
-    ARM_COMPUTE_ERROR_ON(_pImpl->dwc_assembly_kernel == nullptr);
+    // Compute memory requirements for assembly kernels
+    constexpr size_t alignment = 4096;
+    _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, dwc_wrapper->get_working_size(num_threads, src->dimension(0)), alignment });
+    _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, dwc_wrapper->get_storage_size(), alignment });
+    _pImpl->asm_kernel = std::move(dwc_wrapper);
+}
 
-    // Create assembly kernel wrapper
-    _pImpl->dwc_acl_kernel.configure(_pImpl->dwc_assembly_kernel.get());
-
-    constexpr size_t alignment = 128;
-
-    // Create workspace
-    const unsigned int num_threads    = NEScheduler::get().num_threads();
-    const size_t       workspace_size = _pImpl->dwc_assembly_kernel->get_working_space_size(num_threads);
-    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "Workspace size cannot be 0 !");
-    _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment });
-
-    // Create packing tensor
-    const size_t pack_tensor_size = _pImpl->dwc_assembly_kernel->get_packed_params_size();
-    ARM_COMPUTE_ERROR_ON_MSG(pack_tensor_size == 0, "Pack tensor size cannot be 0 !");
-
-    _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, pack_tensor_size, alignment });
+Status CpuDepthwiseConv2dAssemblyDispatch::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info)
+{
+    return kernels::CpuDepthwiseConv2dAssemblyWrapperKernel::validate(src, weights, bias, dst, info);
 }
 
 experimental::MemoryRequirements CpuDepthwiseConv2dAssemblyDispatch::workspace() const
@@ -391,165 +88,40 @@
     return _pImpl->mem_req;
 }
 
-Status CpuDepthwiseConv2dAssemblyDispatch::validate(const ITensorInfo     *src,
-                                                    const ITensorInfo     *weights,
-                                                    const ITensorInfo     *bias,
-                                                    const ITensorInfo     *dst,
-                                                    const ConvolutionInfo &info)
+bool CpuDepthwiseConv2dAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
-    if(weights->data_type() != DataType::QSYMM8_PER_CHANNEL)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
-    }
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, weights);
-
-    // Validate convolver
-    ARM_COMPUTE_RETURN_ERROR_ON(!is_optimized_supported(src, weights, info));
-
-    // Validate activation
-    const bool is_relu  = arm_compute::utils::info_helpers::is_relu(info.act_info);
-    const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info);
-    ARM_COMPUTE_RETURN_ERROR_ON(info.act_info.enabled() && !(is_relu || is_relu6));
-
-    // Check bias
-    if(bias != nullptr)
-    {
-        unsigned int channel_idx = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::CHANNEL);
-        ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
-        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(channel_idx));
-    }
-
-    // Check output
-    if(dst->total_size() != 0)
-    {
-        const TensorShape dst_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*src, *weights, info);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
-    }
-
-    // The uniform quantization case will only have 1 scale value in the weights quantization info
-    const UniformQuantizationInfo src_qinfo     = src->quantization_info().uniform();
-    const QuantizationInfo        weights_qinfo = weights->quantization_info();
-    const UniformQuantizationInfo dst_qinfo     = dst->quantization_info().uniform();
-    for(auto const s : weights_qinfo.scale())
-    {
-        const float fmultipler = src_qinfo.scale * s / dst_qinfo.scale;
-        ARM_COMPUTE_RETURN_ERROR_ON(fmultipler > 1.f);
-    }
-
-    return Status{};
-}
-
-bool CpuDepthwiseConv2dAssemblyDispatch::is_optimized_supported(const ITensorInfo     *src,
-                                                                const ITensorInfo     *weights,
-                                                                const ConvolutionInfo &info)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
-
-    // Reshape input shape if in NHWC format
-    const DataLayout data_layout = src->data_layout();
-    TensorShape      in_shape{ src->tensor_shape() };
-    if(data_layout == DataLayout::NHWC)
-    {
-        in_shape.set(Window::DimX, src->tensor_shape().y());
-        in_shape.set(Window::DimY, src->tensor_shape().z());
-        in_shape.set(Window::DimZ, src->tensor_shape().x());
-    }
-
-    // Check data type
-    const DataType input_type            = src->data_type();
-    const bool     is_input_type_valid   = is_data_type_float(input_type) || input_type == DataType::QASYMM8;
-    const DataType weights_type          = weights->data_type();
-    const bool     is_weights_type_valid = is_data_type_float(weights_type) || weights_type == DataType::QASYMM8 || weights_type == DataType::QASYMM8_SIGNED
-                                           || weights_type == DataType::QSYMM8_PER_CHANNEL;
-
-    // Check weighs size
-    std::set<unsigned int> supported_kernel_sizes = { 3, 5 };
-    const unsigned int     width_idx              = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
-    const unsigned int     height_idx             = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    const unsigned int     kernel_w               = weights->dimension(width_idx);
-    const unsigned int     kernel_h               = weights->dimension(height_idx);
-    bool                   weights_supported      = (kernel_w == kernel_h) && (supported_kernel_sizes.count(kernel_w) != 0);
-
-    // Check for supported strides
-    const auto &strides           = info.pad_stride_info.stride();
-    bool        supported_strides = (strides.first == strides.second) && ((strides.first == 1) || (strides.first == 2));
-
-    // Check for supported padding
-    const auto    pad_top           = info.pad_stride_info.pad_top();
-    const auto    pad_right         = info.pad_stride_info.pad_right();
-    const auto    pad_bottom        = info.pad_stride_info.pad_bottom();
-    const auto    pad_left          = info.pad_stride_info.pad_left();
-    PadStrideInfo same_pad          = calculate_same_pad(in_shape, TensorShape(kernel_w, kernel_h), info.pad_stride_info, DataLayout::NCHW, info.dilation);
-    bool          is_same_padding   = (pad_top == same_pad.pad_top()) && (pad_right == same_pad.pad_right()) && (pad_bottom == same_pad.pad_bottom()) && (pad_left == same_pad.pad_left());
-    bool          is_valid_padding  = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0);
-    bool          supported_padding = is_same_padding || is_valid_padding;
-    // TODO(COMPMID-2464): Enable once dilated conv with stride 2 is supported
-    bool is_dilation_supported = ((info.dilation == Size2D(1U, 1U)) || ((info.dilation.x() == info.dilation.y()) && strides.first == 1));
-
-    if(weights_type == DataType::QSYMM8_PER_CHANNEL)
-    {
-        is_dilation_supported = is_dilation_supported && (info.dilation == Size2D(1U, 1U));
-    }
-
-    return is_input_type_valid && is_weights_type_valid && weights_supported && supported_strides && supported_padding && (info.depth_multiplier == 1) && is_dilation_supported;
+    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
+    return act.type != arm_gemm::Activation::Type::None;
 }
 
 void CpuDepthwiseConv2dAssemblyDispatch::run(ITensorPack &tensors)
 {
-    // Prepare assembly kernel
+    ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
+
     prepare(tensors);
 
-    auto src       = tensors.get_tensor(TensorType::ACL_SRC_0);
-    auto workspace = tensors.get_tensor(TensorType::ACL_INT_0);
-    auto dst       = tensors.get_tensor(TensorType::ACL_DST);
-
-    // Setup inputs/outputs
-    ARM_COMPUTE_ERROR_ON(workspace == nullptr && workspace->buffer() == nullptr);
-    _pImpl->dwc_assembly_kernel->set_working_space(static_cast<void *>(workspace->buffer()));
-
-    ARM_COMPUTE_ERROR_ON(workspace->buffer() == nullptr);
-    const int   input_element_size = src->info()->element_size();
-    const int   input_batch_stride = src->info()->strides_in_bytes()[3] / input_element_size;
-    const int   input_row_stride   = src->info()->strides_in_bytes().z() / input_element_size;
-    const int   input_col_stride   = src->info()->strides_in_bytes().y() / input_element_size;
-    const void *input_ptr          = src->buffer() + src->info()->offset_first_element_in_bytes();
-    _pImpl->dwc_assembly_kernel->set_input(input_ptr, input_batch_stride, input_row_stride, input_col_stride);
-
-    ARM_COMPUTE_ERROR_ON(dst->buffer() == nullptr);
-    const int output_element_size = dst->info()->element_size();
-    const int output_batch_stride = dst->info()->strides_in_bytes()[3] / output_element_size;
-    const int output_row_stride   = dst->info()->strides_in_bytes().z() / output_element_size;
-    const int output_col_stride   = dst->info()->strides_in_bytes().y() / output_element_size;
-    void     *output_ptr          = dst->buffer() + dst->info()->offset_first_element_in_bytes();
-    _pImpl->dwc_assembly_kernel->set_output(output_ptr, output_batch_stride, output_row_stride, output_col_stride);
-
-    // Schedule assembly kernel
-    NEScheduler::get().schedule(&_pImpl->dwc_acl_kernel, Window::DimX);
+    NEScheduler::get().schedule_op(_pImpl->asm_kernel.get(), Window::DimY, _pImpl->asm_kernel->window(), tensors);
 }
 
 void CpuDepthwiseConv2dAssemblyDispatch::prepare(ITensorPack &tensors)
 {
     if(!_pImpl->is_prepared)
     {
-        auto weights        = tensors.get_const_tensor(TensorType::ACL_SRC_1);
-        auto bias           = tensors.get_const_tensor(TensorType::ACL_SRC_2);
-        auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_1);
-
-        ARM_COMPUTE_ERROR_ON(packed_weights->buffer() == nullptr);
-
         // Pack weights and bias
-        const int weights_element_size = weights->info()->element_size();
-        const int weights_row_stride   = weights->info()->strides_in_bytes().z() / weights_element_size;
-        const int weights_col_stride   = weights->info()->strides_in_bytes().y() / weights_element_size;
-        _pImpl->dwc_assembly_kernel->pack_params(packed_weights->buffer(),
-                                                 weights->buffer() + weights->info()->offset_first_element_in_bytes(),
-                                                 weights_row_stride,
-                                                 weights_col_stride,
-                                                 (bias != nullptr) ? bias->buffer() : nullptr);
-        _pImpl->dwc_assembly_kernel->set_packed_params_buffer(packed_weights->buffer());
+        const ITensor *weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        const ITensor *bias    = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+        ITensor       *storage = tensors.get_tensor(TensorType::ACL_INT_1);
+
+        const auto weights_ptr    = weights->buffer() + weights->info()->offset_first_element_in_bytes();
+        const auto bias_ptr       = (bias) ? bias->buffer() + bias->info()->offset_first_element_in_bytes() : nullptr;
+        auto       parameters_ptr = storage->buffer() + storage->info()->offset_first_element_in_bytes();
+
+        const auto weights_shape   = weights->info()->tensor_shape();
+        const auto weights_padding = weights->info()->padding();
+
+        const size_t ld_weights_col = weights_shape[0] + weights_padding.left + weights_padding.right;
+        const size_t ld_weights_row = ld_weights_col * (weights_shape[1] + weights_padding.top + weights_padding.bottom);
+        _pImpl->asm_kernel->pack_parameters(parameters_ptr, bias_ptr, weights_ptr, ld_weights_col, ld_weights_row);
 
         weights->mark_as_unused();
         if(bias != nullptr)
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
index 195942b..7084516 100644
--- a/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.h
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
-#ifndef ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H
-#define ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H
+#ifndef ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H
+#define ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H
 
 #include "src/core/common/Macros.h"
 #include "src/runtime/cpu/ICpuOperator.h"
@@ -40,15 +40,15 @@
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDepthwiseConv2dAssemblyDispatch);
     /** Default destructor */
     ~CpuDepthwiseConv2dAssemblyDispatch();
-
     /** Initialize the function's source, destination, kernels and border_size.
      *
      * @note Supports only NHWC format
      *
-     * @param[in]  src     Source tensor info. Data type supported: QASYMM8/F16/F32. (Written to only for border filling).
-     * @param[in]  weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p src.
+     * @param[in]  src     Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+     * @param[in]  weights Weights tensor info. These are 3D tensors with shape [W, H, IFM] (NCHW notation; the NHWC tensor shape is [IFM, W, H]).
+     *                     Data type supported: same as @p src or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p src is QASYMM8/QASYMM8_SIGNED.
      * @param[in]  bias    (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
-     *                     Data type supported: Same as @p src.
+     *                     Data type supported: same as @p src or S32 if @p src is quantized.
      * @param[out] dst     Destination tensor info. Data type supported: same as @p src.
      * @param[in]  info    Depthwise convolution meta-data.
      */
@@ -60,18 +60,13 @@
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info);
-    /** Check if the optimized kernel can be used for the given kernel sizes and strides
+    /** Checks if activation is supported by the assembly kernels
      *
-     * @warning Even if this return true the inputs and outputs might need to get permuted as the only layout supported is NHWC
+     * @param[in] activation Activation to check
      *
-     * @param[in] src     Input tensor info.
-     * @param[in] weights Weights tensor info.
-     * @param[in] info    Depthwise convolution meta-data.
-     *
-     * @return True if the assembly kernel could be used else false. Note that transformations of input/output could be needed.
+     * @return True if the activation is supported by the assembly kernels, false otherwise
      */
-    static bool is_optimized_supported(const ITensorInfo *src, const ITensorInfo *weights, const ConvolutionInfo &info);
-
+    static bool is_activation_supported(const ActivationLayerInfo &activation);
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
     void prepare(ITensorPack &tensors) override;
@@ -83,4 +78,4 @@
 };
 } // namespace cpu
 } // namespace arm_compute
-#endif /* ARM_COMPUTE_CPU_DEPTHWISECONV2DASSEMBLYDISPATCH_H */
+#endif /* ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_DISPATCH_H */
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index ea3742f..1101e05 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -27,6 +27,7 @@
 #include "src/core/CPP/Validate.h"
 #include "src/core/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
 #include "src/core/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/core/utils/AssemblyUtils.h"
 
 #include <arm_neon.h>
 #include <cstdlib>
@@ -89,38 +90,6 @@
     return p;
 }
 
-arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
-{
-    arm_gemm::Activation gemm_act;
-
-    // Early exit in case lower bound is other than 0, as it's not yet supported
-    if(act.b() != 0.f)
-    {
-        return gemm_act;
-    }
-
-    switch(act.activation())
-    {
-        case ActivationLayerInfo::ActivationFunction::RELU:
-            gemm_act.type = arm_gemm::Activation::Type::ReLU;
-            break;
-        case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
-            gemm_act.type   = arm_gemm::Activation::Type::BoundedReLU;
-            gemm_act.param1 = act.a();
-            gemm_act.param2 = 0.f;
-            break;
-        case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
-            gemm_act.type   = arm_gemm::Activation::Type::BoundedReLU;
-            gemm_act.param1 = act.a();
-            gemm_act.param2 = act.b();
-            break;
-        default:
-            gemm_act.type = arm_gemm::Activation::Type::None;
-    }
-
-    return gemm_act;
-}
-
 IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type)
 {
     // Schedule assembly kernel
@@ -788,14 +757,14 @@
 
 bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation)
 {
-    arm_gemm::Activation act = map_to_arm_gemm_activation(activation);
+    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(activation);
     return act.type != arm_gemm::Activation::Type::None;
 }
 
 void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-    arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
+    arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
 
     //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
     if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))