COMPMID-3776: Indirect GEMM

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index 901b1e8..cc5f160 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -27,27 +27,12 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/core/NEON/kernels/NECol2ImKernel.h"
-#include "src/core/NEON/kernels/NEConvertQuantizedSignednessKernel.h"
-#include "src/core/NEON/kernels/NECopyKernel.h"
-#include "src/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
-#include "src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
-#include "src/core/NEON/kernels/NEFFTDigitReverseKernel.h"
-#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h"
-#include "src/core/NEON/kernels/NEFFTScaleKernel.h"
-#include "src/core/NEON/kernels/NEFillBorderKernel.h"
-#include "src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
-#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
-#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h"
-#include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
-#include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
-#include "src/core/NEON/kernels/NEGEMMMatrixAdditionKernel.h"
-#include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
-#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
-#include "src/core/NEON/kernels/NEIm2ColKernel.h"
-#include "src/core/NEON/kernels/NEPadLayerKernel.h"
-#include "src/core/NEON/kernels/NEReductionOperationKernel.h"
-#include "src/core/NEON/kernels/NEWeightsReshapeKernel.h"
+#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEFFTConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
+
 #include "support/MemorySupport.h"
 
 #include <cmath>
@@ -71,6 +56,7 @@
     ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info,
                                                             enable_fast_math, num_groups));
 
+    const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups);
     switch(NEConvolutionLayer::get_convolution_method(input->info(), weights->info(), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math))
     {
         case ConvolutionMethod::WINOGRAD:
@@ -87,6 +73,13 @@
             _function = std::move(f);
             break;
         }
+        case ConvolutionMethod::GEMM_CONV2D:
+        {
+            auto f = arm_compute::support::cpp14::make_unique<NEGEMMConv2d>(_memory_manager);
+            f->configure(input, weights, biases, output, info);
+            _function = std::move(f);
+            break;
+        }
         case ConvolutionMethod::DIRECT:
         {
             auto f = arm_compute::support::cpp14::make_unique<NEDirectConvolutionLayer>(_memory_manager);
@@ -112,22 +105,22 @@
 {
     ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1), "Grouping (num_groups != 1) is not supported on NEON");
 
+    const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, num_groups);
     switch(NEConvolutionLayer::get_convolution_method(input, weights, output, conv_info, weights_info, dilation, act_info, enable_fast_math))
     {
         case ConvolutionMethod::WINOGRAD:
-            //Validate Winograd
             ARM_COMPUTE_RETURN_ON_ERROR(NEWinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info, enable_fast_math));
             break;
         case ConvolutionMethod::GEMM:
-            //Validate Gemm-based Convolution
             ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info));
             break;
+        case ConvolutionMethod::GEMM_CONV2D:
+            ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMConv2d::validate(input, weights, biases, output, info));
+            break;
         case ConvolutionMethod::DIRECT:
-            //Validate Direct Convolution
             ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayer::validate(input, weights, biases, output, conv_info, act_info));
             break;
         case ConvolutionMethod::FFT:
-            // Validate FFT-based convolution layer
             ARM_COMPUTE_RETURN_ON_ERROR(NEFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info));
             break;
         default:
@@ -149,6 +142,8 @@
     const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
     const size_t idx_c = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
 
+    const Conv2dInfo info(conv_info, dilation, act_info, enable_fast_math, 1);
+
     /* Input spatial dims, kernel size, IFM/OFM, conv info */
     using ConvolutionConfiguration = std::tuple<Size2D, Size2D, Size2D, PadStrideInfo>;
     using ConfigurationMethod      = std::pair<ConvolutionConfiguration, ConvolutionMethod>;
@@ -235,7 +230,21 @@
             }
         }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-        return bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)) ? ConvolutionMethod::WINOGRAD : ConvolutionMethod::GEMM;
+        // For 1x1 convolutions, run the default GEMM
+        if(weights->dimension(idx_w) == 1 && weights->dimension(idx_h) == 1)
+        {
+            return ConvolutionMethod::GEMM;
+        }
+
+        if(bool(NEWinogradConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math)))
+        {
+            return ConvolutionMethod::WINOGRAD;
+        }
+        if(bool(NEGEMMConv2d::validate(input, weights, nullptr, output, info)))
+        {
+            return ConvolutionMethod::GEMM_CONV2D;
+        }
+        return ConvolutionMethod::GEMM;
     }
 }
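
For illustration, a minimal sketch of querying the extended heuristic from user code. The shapes, layout calls, and includes below are assumptions for the example, not part of the patch:

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"

using namespace arm_compute;

ConvolutionMethod query_method_example()
{
    // Assumed 3x3 F32 convolution in NHWC: src/dst shaped (C, W, H, N), weights (Cin, kW, kH, Cout).
    TensorInfo src(TensorShape(64U, 56U, 56U, 1U), 1, DataType::F32);
    TensorInfo wei(TensorShape(64U, 3U, 3U, 64U), 1, DataType::F32);
    TensorInfo dst(TensorShape(64U, 56U, 56U, 1U), 1, DataType::F32);
    src.set_data_layout(DataLayout::NHWC);
    wei.set_data_layout(DataLayout::NHWC);
    dst.set_data_layout(DataLayout::NHWC);

    // Selection order after this patch: 1x1 kernels short-circuit to GEMM,
    // then Winograd is tried, then the assembly-backed GEMM_CONV2D, then GEMM.
    return NEConvolutionLayer::get_convolution_method(&src, &wei, &dst,
                                                      PadStrideInfo(1, 1, 1, 1), WeightsInfo(),
                                                      Size2D(1U, 1U), ActivationLayerInfo(),
                                                      false /* enable_fast_math */);
}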
 
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 0215098..9f52e45 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -47,7 +47,19 @@
 
 namespace arm_compute
 {
-NEGEMM::~NEGEMM() = default;
+namespace
+{
+AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
+{
+    AsmGemmInfo asm_info;
+    asm_info.method                  = AsmConvMethod::Im2Col;
+    asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
+    asm_info.depth_output_gemm3d     = info.depth_output_gemm3d();
+    asm_info.activation_info         = info.activation_info();
+
+    return asm_info;
+}
+} // namespace
 
 NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(),
@@ -56,12 +68,15 @@
 {
 }
 
+NEGEMM::~NEGEMM() = default;
+
 void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
 {
     ARM_COMPUTE_ERROR_THROW_ON(NEGEMM::validate(a->info(), b->info(), (c != nullptr) ? c->info() : nullptr, d->info(), alpha, beta, gemm_info));
 
-    const bool is_c_bias     = gemm_info.reshape_b_only_on_first_run();
-    bool       run_optimised = bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), gemm_info));
+    const AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
+    const bool        is_c_bias     = gemm_info.reshape_b_only_on_first_run();
+    bool              run_optimised = bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), (is_c_bias && c != nullptr) ? c->info() : nullptr, d->info(), asm_info));
 
     // Check if we need to reshape the matrix B only on the first run
     _is_prepared                      = false;
@@ -76,7 +91,7 @@
     if(run_optimised)
     {
         const ITensor *c_to_use = is_c_bias ? c : nullptr;
-        _asm_glue.configure(a, b, c_to_use, d, gemm_info);
+        _asm_glue.configure(a, b, c_to_use, d, asm_info);
         ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
 
         // Scale product by alpha
@@ -221,7 +236,8 @@
     }
 
     // Check if we need to run the optimized assembly kernel
-    const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, gemm_info));
+    AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
+    const bool  run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, output, asm_info));
 
     if(!run_optimised)
     {
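
NEGEMM's public interface is unchanged by this patch; a GEMMInfo is still accepted and translated to AsmGemmInfo internally via init_assembly_metadata() before reaching the assembly dispatch. A minimal usage sketch, with 2D F32 shapes assumed for the example:

#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void negemm_example()
{
    // A: 32x64, B: 64x128, D: 32x128 (TensorShape is ordered (columns, rows)).
    Tensor a, b, d;
    a.allocator()->init(TensorInfo(TensorShape(64U, 32U), 1, DataType::F32));
    b.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
    d.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::F32));

    NEGEMM gemm;
    gemm.configure(&a, &b, nullptr, &d, 1.f, 0.f, GEMMInfo());
    a.allocator()->allocate();
    b.allocator()->allocate();
    d.allocator()->allocate();
    gemm.run(); // dispatches to the assembly path when NEGEMMAssemblyDispatch::validate() accepts it
}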
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 5b08483..400fa64 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -25,18 +25,70 @@
 
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
 #include "src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h"
 #include "src/core/NEON/kernels/assembly/arm_gemm.hpp"
 
 #include "support/MemorySupport.h"
 
 #include <arm_neon.h>
+#include <cstdlib>
 
 namespace arm_compute
 {
 namespace
 {
+struct free_delete
+{
+    void operator()(void *x)
+    {
+        free(x);
+    }
+};
+
+struct Params
+{
+    unsigned int M;
+    unsigned int N;
+    unsigned int K;
+    unsigned int batches;
+    unsigned int multis;
+    unsigned int sections;
+    bool         indirect;
+};
+
+Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d, const AsmGemmInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+
+    Params p;
+    p.K        = a->info()->tensor_shape().x();
+    p.N        = d->info()->tensor_shape().x();
+    p.multis   = 1;
+    p.indirect = false;
+    p.sections = 1;
+
+    if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
+    {
+        p.indirect = true;
+        p.sections = b->info()->tensor_shape()[2] * b->info()->tensor_shape()[3];
+    }
+    else
+    {
+        p.M       = d->info()->tensor_shape().y();
+        p.multis  = b->info()->tensor_shape().z();
+        p.batches = d->info()->tensor_shape().total_size_upper(2) / p.multis; // COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
+    }
+
+    // Update M in case of GEMM3D for output
+    if(info.depth_output_gemm3d != 0)
+    {
+        p.M       = d->info()->tensor_shape().y() * d->info()->tensor_shape().z();
+        p.batches = d->info()->tensor_shape().total_size_upper(3) / p.multis;
+    }
+
+    return p;
+}
+
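
A worked reading of extract_parameters() for the plain (Im2Col) branch, with shapes assumed for the example:

// a: TensorShape(64, 32, 4)  -> p.K = a.x() = 64
// b: TensorShape(128, 64, 1) -> p.multis = b.z() = 1
// d: TensorShape(128, 32, 4) -> p.N = d.x() = 128, p.M = d.y() = 32,
//                               p.batches = d.total_size_upper(2) / p.multis = 4
// For AsmConvMethod::Conv/Indirect, p.indirect is set and p.sections becomes the
// kernel spatial size b.shape[2] * b.shape[3]; since the convolution path always
// sets depth_output_gemm3d, M and batches are then taken from the GEMM3D branch.
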
 arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
 {
     arm_gemm::Activation gemm_act;
@@ -69,6 +121,29 @@
     return gemm_act;
 }
 
+IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type)
+{
+    // Schedule assembly kernel
+    const int         granule_threshold = 200;
+    IScheduler::Hints scheduling_hint   = IScheduler::Hints(Window::DimX);
+    if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32)
+    {
+        scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
+    }
+    else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8))
+    {
+        // GEMM_INTERLEAVED supports 2D parallelism; IScheduler::split_dimensions_all signals to parallelise over all window dimensions
+        scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+    }
+    else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED))
+    {
+        // Special case for QASYMM8 to support 2D parallelism; the scheduler here may be tweaked differently compared to the FP32 case
+        scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+    }
+
+    return scheduling_hint;
+}
+
 template <typename TypeInput, typename TypeOutput>
 class FallbackTransform : public ITransformWeights
 {
@@ -165,7 +240,7 @@
      * @param[in]  os              Output stage meta-data.
      */
     void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
-                   arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
+                   arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                    MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
 
     /** Set requantization shifts to be used
@@ -198,6 +273,16 @@
      * @param[in] alignment      Workspace memory alignment.
      */
     void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+    /** Configure the indirect buffer
+     *
+     * @param[in]  a    Input tensor containing the Matrix A.
+     * @param[in]  b    Input tensor containing the Matrix B.
+     * @param[out] d    Output tensor to store the result of matrix multiplication.
+     * @param[in]  info GEMM meta-data.
+     */
+    void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
+    /** Prepare the indirect buffer */
+    void prepare_indirect_buffer();
 
     /** Assembly Gemm kernel */
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
@@ -226,7 +311,7 @@
     /** Prepared flag */
     bool _is_prepared{ false };
     /** GEMM meta-data */
-    GEMMInfo _gemm_info{};
+    AsmGemmInfo _gemm_info{};
     /** Weights manager */
     IWeightsManager *_weights_manager{ nullptr };
     /** Weights transform object */
@@ -239,11 +324,16 @@
     std::vector<int32_t> left_shifts{};
     /** Per channel quantization multipliers */
     std::vector<int32_t> _multipliers{};
+    /** Indirect buffer */
+    std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
+    std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
+    std::vector<TypeInput>          _indirect_pad{};
+    arm_gemm::ConvolutionParameters _cp{};
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
-        const std::vector<int32_t> &multipliers)
+std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
+Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
 {
     _multipliers   = multipliers;
     _shifts        = shifts;
@@ -261,8 +351,122 @@
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer()
+{
+    const TypeInput *A_ptr          = reinterpret_cast<TypeInput *>(_a->buffer());
+    const int        multis         = 1;
+    const int        batches        = _a->info()->tensor_shape().total_size_upper(3);
+    const size_t     stride_A       = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    const size_t     batch_stride_A = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+    const size_t     multi_stride_A = _a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
+
+    const size_t output_hw    = _cp.output_height * _cp.output_width;
+    const int    batch_size   = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
+    const size_t batch_stride = batch_size / sizeof(TypeInput);
+    const int    multi_size   = batch_size * batches;
+    const size_t multi_stride = multi_size / sizeof(TypeInput);
+
+    for(int64_t m = 0; m < multis; m++)
+    {
+        for(int64_t b = 0; b < batches; b++)
+        {
+            for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
+            {
+                for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
+                {
+                    int64_t output_xy = (output_y * _cp.output_width) + output_x;
+
+                    for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
+                    {
+                        for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
+                        {
+                            int64_t input_x   = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
+                            int64_t input_y   = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
+                            int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
+                            int64_t input_xy  = (input_y * _cp.input_width) + input_x;
+
+                            if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
+                            {
+                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
+                            }
+                            else
+                            {
+                                _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+                                    A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
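
A short trace of the fill loop above, under assumed convolution parameters (3x3 kernel, stride 1, padding 1, 4x4 input and output, one batch/multi):

// Output pixel (0, 0), kernel tap (0, 0):
//   input_x = 0 * 1 + 0 - 1 = -1, input_y = -1 -> out of bounds, so the entry
//   points at _indirect_pad.data(), a vector of input_channels padding values.
// Output pixel (0, 0), kernel tap (1, 1):
//   input_x = input_y = 0, input_xy = 0 * 4 + 0 = 0 -> the entry points at
//   A_ptr + 0 * stride_A, i.e. the channel vector of the first input pixel.
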
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
+
+    float zeropad = 0.f;
+    if(is_data_type_quantized(a->data_type()))
+    {
+        zeropad = a->quantization_info().uniform().offset;
+    }
+
+    const int64_t input_width    = static_cast<int64_t>(a->tensor_shape()[1]);
+    const int64_t input_height   = static_cast<int64_t>(a->tensor_shape()[2]);
+    const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+    const int64_t kernel_width   = static_cast<int64_t>(b->tensor_shape()[2]);
+    const int64_t kernel_height  = static_cast<int64_t>(b->tensor_shape()[3]);
+    const int64_t output_width   = static_cast<int64_t>(d->tensor_shape()[1]);
+    const int64_t output_height  = static_cast<int64_t>(d->tensor_shape()[2]);
+
+    _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
+            info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
+          };
+
+    if(info.method == AsmConvMethod::Conv)
+    {
+        _gemm_kernel_asm->set_convolution_parameters(_cp);
+    }
+
+    if(info.method == AsmConvMethod::Indirect)
+    {
+        const unsigned int multis    = 1;
+        const unsigned int batches   = a->tensor_shape().total_size_upper(3);
+        const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
+        const unsigned int output_hw = _cp.output_width * _cp.output_height;
+
+        using TypeInputPtr        = TypeInput *;
+        const int    batch_size   = kernel_hw * output_hw * sizeof(TypeInputPtr);
+        const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
+        const int    multi_size   = batch_size * batches;
+        const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
+
+        _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
+        _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+        _indirect_pad = std::vector<TypeInput>(_cp.input_channels, zeropad);
+
+        // Set indirect argument
+        int64_t pos = 0;
+        for(int64_t m = 0; m < multis; m++)
+        {
+            for(int64_t b = 0; b < batches; b++)
+            {
+                for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
+                {
+                    (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+                }
+            }
+        }
+
+        _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
+    }
+}
+
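
Sizing the two tables above, with values assumed for the example (3x3 kernel, 4x4 output, 2 batches, 1 multi):

// kernel_hw = 9, output_hw = 16, batch_stride = 9 * 16 = 144 pointers.
// _indirect_buf holds 1 * 2 * 144 = 288 input-pixel pointers; _indirect_arg
// holds 1 * 2 * 9 = 18 entries. The entry for (b = 1, kernel_xy = 4) is
// _indirect_buf.get() + 1 * 144 + 4 * 16, i.e. the 16 output positions seen
// through the kernel's centre tap for the second batch.
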
+template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
-                                                             arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
+                                                             arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
                                                              MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
 {
     arm_gemm::GemmConfig gemm_cfg;
@@ -325,6 +529,12 @@
             static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
         }
     }
+
+    // Handle indirect GEMM convolution
+    if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
+    {
+        configure_indirect(a->info(), b->info(), d->info(), gemm_info);
+    }
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -365,6 +575,11 @@
             }
         }
 
+        if(_gemm_info.method == AsmConvMethod::Indirect)
+        {
+            prepare_indirect_buffer();
+        }
+
         _is_prepared = true;
     }
 }
@@ -387,23 +602,23 @@
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::run()
 {
-    const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+    int       lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
     int       ldb = 0;
     const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
 
-    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+    const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
     const size_t a_multi_idx = a_batch_idx + 1;
-    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+    const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
     const size_t d_multi_idx = d_batch_idx + 1;
 
-    const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+    int       batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
     const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
 
-    const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+    int       multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
     int       multi_stride_b = 0;
     const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
 
-    const auto       in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+    auto             in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
     const TypeInput *in1_ptr = nullptr;
     auto             out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
 
@@ -415,25 +630,7 @@
         in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
     }
 
-    IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX);
-    if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && _d->info()->data_type() == DataType::F32)
-    {
-        const int granule_threshold = 200;
-        scheduling_hint             = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
-    }
-    else if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (_d->info()->data_type() == DataType::F32 || _d->info()->data_type() == DataType::F16
-                                                                                 || _d->info()->data_type() == DataType::U8 || _d->info()->data_type() == DataType::S8))
-    {
-        //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions
-        const int granule_threshold = 200;
-        scheduling_hint             = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
-    }
-    else if(_kernel_info.method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (_d->info()->data_type() == DataType::QASYMM8 || _d->info()->data_type() == DataType::QASYMM8_SIGNED))
-    {
-        //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case
-        const int granule_threshold = 200;
-        scheduling_hint             = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
-    }
+    const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, _d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
     if(_workspace.buffer() != nullptr)
@@ -458,57 +655,67 @@
     // Prepare assembly kernel
     prepare();
 
-    TypeOutput *bias = nullptr;
     // Set up matrix bias in the assembly kernel; it's just a pointer to matrix C.
+    TypeOutput *bias = nullptr;
     if(_c && _c->info()->data_type() != DataType::S32)
     {
         bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
     }
+
+    if(_gemm_info.method == AsmConvMethod::Indirect)
+    {
+        in0_ptr        = nullptr;
+        lda            = 0;
+        batch_stride_a = 0;
+        multi_stride_a = 0;
+    }
+
     // Set gemm parameters
     _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
                                  in1_ptr, ldb, multi_stride_b,
                                  out_ptr, ldd, batch_stride_d, multi_stride_d,
                                  bias, 0);
-    // Schedule assembly kernel
+    // Schedule
     NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
 }
 
 template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                     const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+                     const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                      IWeightsManager *weights_manager)
 {
-    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
-    const CPUInfo               &ci          = NEScheduler::get().cpu_info();
-    unsigned int                 num_threads = NEScheduler::get().num_threads();
+    Params         p           = extract_parameters(a, b, d, info);
+    const CPUInfo &ci          = NEScheduler::get().cpu_info();
+    unsigned int   num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads);
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads);
 
     // Create arm_gemm fallback
     auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
-    fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
+    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
     arm_gemm = std::move(fallback);
 }
 
 template <typename TypeInput, typename TypeOutput>
 void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                           const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+                           const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                            IWeightsManager *weights_manager)
 {
     ARM_COMPUTE_UNUSED(activation);
-    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
-    const CPUInfo               &ci          = NEScheduler::get().cpu_info();
-    unsigned int                 num_threads = NEScheduler::get().num_threads();
+    Params         p           = extract_parameters(a, b, d, info);
+    const CPUInfo &ci          = NEScheduler::get().cpu_info();
+    unsigned int   num_threads = NEScheduler::get().num_threads();
 
-    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads);
+    arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads);
 
     // Create arm_gemm fallback
     auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
 
     // Configure requantization info
-    const int32_t                 a_offset = -a->info()->quantization_info().uniform().offset;
-    const int32_t                 b_offset = -b->info()->quantization_info().uniform().offset;
-    const GEMMLowpOutputStageInfo os_info  = gemm_info.gemmlowp_output_stage();
+    const int32_t                 negation = info.negated_offsets ? 1 : -1;
+    const int32_t                 a_offset = -a->info()->quantization_info().uniform().offset * negation;
+    const int32_t                 b_offset = -b->info()->quantization_info().uniform().offset * negation;
+    const GEMMLowpOutputStageInfo os_info  = info.output_stage;
 
     arm_gemm::Requantize32 gemm_requant_info{};
     if(os_info.gemmlowp_shifts.size() > 1)
@@ -530,7 +737,7 @@
     }
 
     // Configure fallback
-    fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
+    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
     arm_gemm = std::move(fallback);
 }
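
The sign convention behind negated_offsets, illustrated assuming AsmGemmInfo defaults the flag to true for the plain GEMM paths:

// negated_offsets == true  (NEGEMM/NEGEMMLowp): negation = 1, so a uniform input
//   offset of 10 gives a_offset = -10, i.e. pre-negated before reaching arm_gemm.
// negated_offsets == false (the NEGEMMConv2d path): negation = -1, so the same
//   offset gives a_offset = +10, leaving the negation to the assembly kernels.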
 
@@ -541,14 +748,13 @@
 {
 }
 
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
 {
-    ARM_COMPUTE_UNUSED(c);
+    ARM_COMPUTE_UNUSED(c, info);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
     ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
 
-    ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.pretranpose_B());
 #ifndef __aarch64__
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
 #endif /* __aarch64__ */
@@ -579,13 +785,13 @@
     return act.type != arm_gemm::Activation::Type::None;
 }
 
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-    arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info());
+    arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
 
     //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
-    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info))
+    if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info))
     {
         return;
     }
@@ -593,40 +799,40 @@
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
             if(d->info()->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             break;
         case DataType::S8:
         case DataType::QASYMM8_SIGNED:
             if(d->info()->data_type() == DataType::S32)
             {
-                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             break;
 #endif /* __aarch64__ */
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
new file mode 100644
index 0000000..642b084
--- /dev/null
+++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
@@ -0,0 +1,167 @@
+/*
+ * Copyright (c) 2020 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include <set>
+namespace arm_compute
+{
+namespace
+{
+GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const ActivationLayerInfo &act)
+{
+    // Since we need negative offsets for computing convolution, we need to change the QuantizationInfo:
+    // extract and negate the input and weights offsets
+    const QuantizationInfo        iqinfo    = input->quantization_info();
+    const QuantizationInfo        wqinfo    = weights->quantization_info();
+    const QuantizationInfo        oqinfo    = (output->total_size() == 0) ? iqinfo : output->quantization_info();
+    const UniformQuantizationInfo uoqinfo   = oqinfo.uniform();
+    const DataType                data_type = input->data_type();
+    // Merge activation with output stage
+    const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
+                                                                               ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
+                                                                               ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
+                                                                             };
+    PixelValue type_min{};
+    PixelValue type_max{};
+    std::tie(type_min, type_max) = get_min_max(data_type);
+    int32_t min_activation = type_min.get<int32_t>();
+    int32_t max_activation = type_max.get<int32_t>();
+    if(supported_acts.count(act.activation()) != 0)
+    {
+        std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
+    }
+    GEMMLowpOutputStageInfo os_info;
+    os_info.type                     = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+    os_info.gemmlowp_offset          = uoqinfo.offset;
+    os_info.gemmlowp_min_bound       = min_activation;
+    os_info.gemmlowp_max_bound       = max_activation;
+    os_info.is_quantized_per_channel = (weights->data_type() == DataType::QSYMM8_PER_CHANNEL);
+    quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, os_info);
+    return os_info;
+}
+AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect)
+{
+    AsmGemmInfo asm_info;
+    asm_info.method                  = is_indirect ? AsmConvMethod::Indirect : AsmConvMethod::Conv;
+    asm_info.ps_info                 = info.conv_info;
+    asm_info.activation_info         = info.act_info;
+    asm_info.depth_output_gemm3d     = true;
+    asm_info.reinterpret_input_as_3d = true;
+    asm_info.padding_top             = info.conv_info.pad_top();
+    asm_info.padding_left            = info.conv_info.pad_left();
+    asm_info.padding_value           = 0.f;
+    asm_info.negated_offsets         = false;
+    return asm_info;
+}
+} // namespace
+
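
A worked example of calculate_output_stage_metadata() above, with quantization values assumed for illustration:

// QASYMM8 output, uniform output offset 128, activation BOUNDED_RELU(6.0):
//   get_min_max(QASYMM8) gives starting bounds [0, 255]; BOUNDED_RELU is in
//   supported_acts, so get_quantized_activation_min_max() tightens them to
//   [quantize(0), quantize(6.0)] in the output quantization space. The returned
//   os_info then requests a fused QUANTIZE_DOWN_FIXEDPOINT stage clamped to
//   those bounds, per-channel when the weights are QSYMM8_PER_CHANNEL.
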
+NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
+    : _gemm_asm_func(memory_manager), _activation_func(), _weights_permute_func(), _original_weights(nullptr), _permuted_weights(), _is_prepared(false), _run_activation(false)
+{
+}
+void NEGEMMConv2d::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const Conv2dInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConv2d::validate(input->info(),
+                                                      weights->info(),
+                                                      biases != nullptr ? biases->info() : nullptr,
+                                                      output->info(),
+                                                      info));
+    _original_weights = weights;
+    _weights_permute_func.configure(weights, &_permuted_weights, PermutationVector{ 3, 0, 1, 2 });
+
+    // Configure assembly dispatch
+    AsmGemmInfo asm_info = init_assembly_metadata(info, false);
+    if(is_data_type_quantized(input->info()->data_type()))
+    {
+        asm_info.output_stage = calculate_output_stage_metadata(input->info(), weights->info(), output->info(), info.act_info);
+    }
+    _gemm_asm_func.configure(input, &_permuted_weights, biases, output, asm_info);
+
+    // Configure activation
+    if(info.act_info.enabled() && !_gemm_asm_func.is_activation_supported(info.act_info))
+    {
+        _activation_func.configure(output, nullptr, info.act_info);
+        _run_activation = true;
+    }
+}
+Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::BFLOAT16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::BFLOAT16, DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() != DataLayout::NHWC, "Data layout supported is NHWC");
+    const DataType    data_type = input->data_type();
+    const TensorShape i_shape   = input->tensor_shape();
+    const TensorShape w_shape   = weights->tensor_shape();
+    ARM_COMPUTE_RETURN_ERROR_ON(w_shape[0] != i_shape[0]);
+    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+    // Validate biases
+    if(biases != nullptr)
+    {
+        if(is_data_type_quantized_asymmetric(data_type))
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
+        }
+        else if(data_type == DataType::BFLOAT16)
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::F32);
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+        }
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+    }
+
+    AsmGemmInfo asm_info = init_assembly_metadata(info, false);
+    ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMAssemblyDispatch::validate(input, weights, biases, output, asm_info));
+    return Status{};
+}
+void NEGEMMConv2d::run()
+{
+    prepare();
+
+    _gemm_asm_func.run();
+    if(_run_activation)
+    {
+        _activation_func.run();
+    }
+}
+void NEGEMMConv2d::prepare()
+{
+    if(!_is_prepared)
+    {
+        _permuted_weights.allocator()->allocate();
+        _weights_permute_func.run();
+        _original_weights->mark_as_unused();
+        _is_prepared = true;
+    }
+}
+} // namespace arm_compute
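
A usage sketch for the new function; shapes and parameters are assumptions for illustration. NEGEMMConv2d accepts NHWC only, with src/dst shaped (C, W, H, N) and weights shaped (Cin, kW, kH, Cout) ahead of the internal permute:

#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void gemm_conv2d_example()
{
    Tensor src, weights, dst;
    src.allocator()->init(TensorInfo(TensorShape(64U, 56U, 56U, 1U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(64U, 3U, 3U, 32U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(32U, 56U, 56U, 1U), 1, DataType::F32));
    src.info()->set_data_layout(DataLayout::NHWC);
    weights.info()->set_data_layout(DataLayout::NHWC);
    dst.info()->set_data_layout(DataLayout::NHWC);

    NEGEMMConv2d conv{ nullptr }; // no external memory manager for this sketch
    conv.configure(&src, &weights, nullptr, &dst,
                   Conv2dInfo(PadStrideInfo(1, 1, 1, 1), Size2D(1U, 1U), ActivationLayerInfo(),
                              false /* enable_fast_math */, 1 /* num_groups */));
    src.allocator()->allocate();
    weights.allocator()->allocate();
    dst.allocator()->allocate();
    conv.run(); // the first run() calls prepare(), which permutes the weights once
}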
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
deleted file mode 100644
index 09637dd..0000000
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Copyright (c) 2017-2020 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "arm_compute/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.h"
-
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "arm_compute/runtime/TensorAllocator.h"
-#include "src/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
-#include "src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
-#include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
-#include "support/MemorySupport.h"
-
-namespace arm_compute
-{
-NEGEMMLowpAssemblyMatrixMultiplyCore::~NEGEMMLowpAssemblyMatrixMultiplyCore() = default;
-
-NEGEMMLowpAssemblyMatrixMultiplyCore::NEGEMMLowpAssemblyMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _asm_glue(memory_manager), _mm_kernel(nullptr), _mtx_a_reshape_kernel(nullptr), _mtx_b_reshape_kernel(nullptr), _tmp_a(), _tmp_b()
-{
-}
-
-void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output)
-{
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::S8);
-    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U32, DataType::S32);
-    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
-    ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(0) != (b)->info()->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
-    ARM_COMPUTE_ERROR_ON_MSG((a)->info()->dimension(1) != (output)->info()->dimension(1), "The output matrix must have the same number of rows as the matrix A");
-    ARM_COMPUTE_ERROR_ON_MSG((b)->info()->dimension(0) != (output)->info()->dimension(0), "The output matrix must have the same number of columns as the matrix B");
-
-    bool run_optimised = false;
-    switch(a->info()->data_type())
-    {
-        case DataType::S8:
-        case DataType::QASYMM8:
-        case DataType::U8:
-        {
-            _asm_glue.configure(a, b, c, output, GEMMInfo(false, false, true));
-            run_optimised = _asm_glue.is_configured();
-            break;
-        }
-        default:
-        {
-            ARM_COMPUTE_ERROR("Datatype not supported");
-            break;
-        }
-    }
-    if(!run_optimised)
-    {
-        // The interleaved output matrix will have the following shape: [ a_height * 4, ceil(a_width / 4.0f) ]
-        TensorShape shape_tmp_a = a->info()->tensor_shape();
-        shape_tmp_a.set(0, a->info()->dimension(0) * 4);
-        shape_tmp_a.set(1, std::ceil(a->info()->dimension(1) / 4.f));
-
-        // The transpose1xW output matrix will have the following shape: [ b_height * 16, ceil(b_width / 16.0f) ]
-        TensorShape shape_tmp_b = b->info()->tensor_shape();
-        shape_tmp_b.set(0, b->info()->dimension(1) * 16);
-        shape_tmp_b.set(1, std::ceil(b->info()->dimension(0) / 16.f));
-
-        TensorInfo info_a(shape_tmp_a, 1, a->info()->data_type());
-        TensorInfo info_b(shape_tmp_b, 1, b->info()->data_type());
-        _tmp_a.allocator()->init(info_a);
-        _tmp_b.allocator()->init(info_b);
-        _memory_group.manage(&_tmp_a);
-        _memory_group.manage(&_tmp_b);
-
-        // Configure interleave kernel
-        {
-            auto k = arm_compute::support::cpp14::make_unique<NEGEMMInterleave4x4Kernel>();
-            k->configure(a, &_tmp_a);
-            _mtx_a_reshape_kernel = std::move(k);
-        }
-
-        // Configure transpose kernel
-        {
-            auto k = arm_compute::support::cpp14::make_unique<NEGEMMTranspose1xWKernel>();
-            k->configure(b, &_tmp_b);
-            _mtx_b_reshape_kernel = std::move(k);
-        }
-
-        // Configure matrix multiply kernel
-        {
-            auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpMatrixMultiplyKernel>();
-            k->configure(&_tmp_a, &_tmp_b, output);
-            _mm_kernel = std::move(k);
-        }
-
-        // Allocate tensors
-        _tmp_a.allocator()->allocate();
-        _tmp_b.allocator()->allocate();
-    }
-}
-
-void NEGEMMLowpAssemblyMatrixMultiplyCore::run()
-{
-    MemoryGroupResourceScope scope_mg(_memory_group);
-    if(_mtx_a_reshape_kernel)
-    {
-        NEScheduler::get().schedule(_mtx_a_reshape_kernel.get(), Window::DimY);
-    }
-
-    if(_mtx_b_reshape_kernel)
-    {
-        NEScheduler::get().schedule(_mtx_b_reshape_kernel.get(), Window::DimY);
-    }
-
-    if(_asm_glue.is_configured())
-    {
-        _asm_glue.run();
-    }
-    else
-    {
-        NEScheduler::get().schedule(_mm_kernel.get(), Window::DimY);
-    }
-}
-} // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 9050427..df8eaac 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -47,6 +47,21 @@
 
 namespace arm_compute
 {
+namespace
+{
+AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
+{
+    AsmGemmInfo asm_info;
+    asm_info.method                  = AsmConvMethod::Im2Col;
+    asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
+    asm_info.depth_output_gemm3d     = info.depth_output_gemm3d();
+    asm_info.activation_info         = info.activation_info();
+    asm_info.output_stage            = info.gemmlowp_output_stage();
+
+    return asm_info;
+}
+} // namespace
+
 using namespace arm_compute::misc::shape_calculator;
 
 NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
@@ -120,6 +135,8 @@
         _mm_result_s32.allocator()->init(info_mm_result_s32);
     }
 
+    // Initialize assembly kernel meta-data
+    const AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
 #ifdef __aarch64__
     switch(a->info()->data_type())
     {
@@ -130,12 +147,12 @@
         {
             if(is_data_type_quantized_asymmetric(a_to_use->info()->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
             {
-                _asm_glue.configure(a_to_use, b, c, output, gemm_info);
+                _asm_glue.configure(a_to_use, b, c, output, asm_info);
                 _fused_assembly_path = _asm_glue.is_configured();
             }
             else
             {
-                _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, gemm_info);
+                _asm_glue.configure(a_to_use, b, nullptr, _fuse_output_stage ? &_mm_result_s32 : output, asm_info);
             }
             _assembly_path = _asm_glue.is_configured();
             break;
@@ -346,17 +363,20 @@
         matrix_a_info = &signed_a;
     }
 
+    // Initialize assembly kernel meta-data
+    const AsmGemmInfo asm_info = init_assembly_metadata(info);
+
     // Check if we need to run the optimized assembly kernel
     bool run_optimised             = false;
     bool run_optimised_requantized = false;
     if(is_data_type_quantized_asymmetric(a_to_use->data_type()) && info.gemmlowp_output_stage().type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
     {
-        run_optimised             = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, gemm_info));
+        run_optimised             = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, c, output, asm_info));
         run_optimised_requantized = run_optimised;
     }
     else
     {
-        run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, gemm_info));
+        run_optimised = bool(NEGEMMAssemblyDispatch::validate(a_to_use, b, nullptr, fuse_output_stage ? &mm_result_s32_info : output, asm_info));
     }
 
     if(run_optimised)
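
A sketch of what the translation above enables, using names from this diff (illustrative only):

GEMMLowpOutputStageInfo os{};
os.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
GEMMInfo info{};
info.set_gemmlowp_output_stage(os);
const AsmGemmInfo asm_info = init_assembly_metadata(info);
// asm_info.output_stage now carries the requantization stage, which is what lets
// NEGEMMAssemblyDispatch::validate() accept the fused requantized path
// (run_optimised_requantized) instead of falling back to the separate
// offset-contribution and output-stage NEON kernels.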
diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
deleted file mode 100644
index d165b22..0000000
--- a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Copyright (c) 2018-2020 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "src/runtime/NEON/functions/NESimpleAssemblyFunction.h"
-
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-
-using namespace arm_compute;
-
-NESimpleAssemblyFunction::NESimpleAssemblyFunction() // NOLINT
-    : _kernel()
-{
-}
-
-void NESimpleAssemblyFunction::run()
-{
-    NEScheduler::get().schedule(_kernel.get(), Window::DimX);
-}
-
-void NESimpleAssemblyFunction::configure(std::unique_ptr<INEGEMMWrapperKernel> kernel)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(kernel.get());
-    _kernel = std::move(kernel);
-    ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(_kernel->window(), 1);
-}
diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.h b/src/runtime/NEON/functions/NESimpleAssemblyFunction.h
deleted file mode 100644
index e9be54d..0000000
--- a/src/runtime/NEON/functions/NESimpleAssemblyFunction.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (c) 2018-2020 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H
-#define ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H
-
-#include "arm_compute/runtime/IFunction.h"
-#include "src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
-
-#include <memory>
-
-namespace arm_compute
-{
-/** Basic interface for functions which have a single NEON GEMM wrapper kernel to run */
-class NESimpleAssemblyFunction : public IFunction
-{
-public:
-    /** Constructor */
-    NESimpleAssemblyFunction();
-
-    /** Configure the function with the kernel to run
-     *
-     * @param[in] kernel GEMM Wrapper kernel configured and ready to run
-     *
-     * @note The kernel is expected to have a 1D window. The function will multi-thread this window across the X dimension.
-     */
-    void configure(std::unique_ptr<INEGEMMWrapperKernel> kernel);
-
-    // Inherited methods overridden:
-    void run() override final;
-
-protected:
-    std::unique_ptr<INEGEMMWrapperKernel> _kernel; /**< Kernel to run */
-};
-} //namespace arm_compute
-#endif /*ARM_COMPUTE_NESIMPLEASSEMBLYFUNCTION_H */