Port NEGEMMConv2d to memory injecting interface

Resolves: COMPMID-4506, COMPMID-4570

Change-Id: I6d37a06da141f1fcfcaa8525322a319cb0234791
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5824
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index e50099d..c2e9f24 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -26,10 +26,10 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuPermute.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+#include "support/Cast.h"
 
 #include <set>
 
@@ -37,6 +37,9 @@
 {
 namespace cpu
 {
+using namespace arm_compute::experimental;
+using namespace arm_compute::utils::cast;
+
 namespace
 {
 GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act)
@@ -87,12 +90,14 @@
 }
 } // namespace
 
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
-    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
       _activation_func(std::make_unique<CpuActivation>()),
       _weights_permute_func(std::make_unique<CpuPermute>()),
-      _permuted_weights_info(),
-      _permuted_weights(std::make_unique<Tensor>())
+      _aux_mem(AuxTensorIdx::Count),
+      _perm_weights(),
+      _run_activation(false),
+      _is_prepared(false)
 {
 }
 
@@ -106,8 +111,10 @@
                                                              biases != nullptr ? biases : nullptr,
                                                              dst,
                                                              info));
-    _original_weights_info = weights;
-    _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 });
+    _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info);
+    _is_prepared    = false;
+
+    _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 });
 
     // Configure assembly dispatch
     cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false);
@@ -115,13 +122,27 @@
     {
         asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info);
     }
-    _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info);
+    _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info);
 
     // Configure activation
-    if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info))
+    if(_run_activation)
     {
         _activation_func->configure(dst, nullptr, info.act_info);
-        _run_activation = true;
+    }
+
+    // Add auxiliary memory requirements of the assembly dispatch
+    auto asm_mem_req           = _gemm_asm_func->workspace();
+    _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+    _aux_mem[Pretranspose]     = asm_mem_req[Pretranspose];
+
+    if(_aux_mem[Pretranspose].size > 0)
+    {
+        // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size());
+    }
+    else
+    {
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
     }
 }
 Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -172,35 +193,29 @@
     }
 }
 
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
-    // TODO: This function will be removed when memory injection is implemeted.
-    ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
-    _permuted_weights->allocator()->free();
-    _permuted_weights->allocator()->init(_permuted_weights_info);
-    _permuted_weights->allocator()->allocate();
-}
-
 void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
 {
     if(!_is_prepared)
     {
-        allocate_permuted_weights();
-        ITensorPack permute_tensors
-        {
-            { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
-            { TensorType::ACL_DST, _permuted_weights.get() },
-        };
+        const ITensor *weights     = tensors.get_const_tensor(ACL_SRC_1);
+        ITensor       *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
+        ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
 
+        CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux);
+        ITensorPack         permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } };
         _weights_permute_func->run(permute_tensors);
 
-        tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused();
+        tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get());
+        // Call prepare of assembly dispatch
+        _gemm_asm_func->prepare(tensors);
 
-        // switch the original tensor with permuted tensor
-        tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get());
         _is_prepared = true;
     }
 }
 
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+    return _aux_mem;
+}
 } // namespace cpu
 } // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2..b572f36 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -24,14 +24,12 @@
 #ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
 #define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
 
-#include "arm_compute/core/ITensorInfo.h"
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "src/core/common/Macros.h"
-#include "src/core/cpu/ICpuKernel.h"
 #include "src/runtime/cpu/ICpuOperator.h"
-
-#include <memory>
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 namespace arm_compute
 {
@@ -40,15 +38,11 @@
 struct Conv2dInfo;
 namespace cpu
 {
-class CpuGemmAssemblyDispatch;
-class CpuActivation;
-class CpuPermute;
-
 class CpuGemmDirectConv2d : public ICpuOperator
 {
 public:
     /** Constructor */
-    CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+    CpuGemmDirectConv2d();
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
     /** Destructor */
     ~CpuGemmDirectConv2d();
@@ -89,22 +83,24 @@
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
     void prepare(ITensorPack &constants) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
+    enum AuxTensorIdx
+    {
+        AsmGemmWorkspace = 0,
+        Pretranspose,
+        PermutedWeights,
+        Count
+    };
+
     std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
     std::unique_ptr<CpuActivation>           _activation_func;
     std::unique_ptr<CpuPermute>              _weights_permute_func;
-    const ITensorInfo                       *_original_weights_info{};
-    TensorInfo                               _permuted_weights_info;
-    std::unique_ptr<Tensor>                  _permuted_weights{ nullptr };
-    bool                                     _is_prepared{ false };
-    bool                                     _run_activation{ false };
-
-    /** Function to allocated a tensor for permuted weights
-     *
-     * @note This function will be removed when memory injection is properly implemented.
-     */
-    void allocate_permuted_weights();
+    experimental::MemoryRequirements         _aux_mem;
+    TensorInfo                               _perm_weights;
+    bool                                     _run_activation;
+    bool                                     _is_prepared;
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 1101e05..79ea1cb 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -27,15 +27,18 @@
 #include "src/core/CPP/Validate.h"
 #include "src/core/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
 #include "src/core/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/core/utils/AssemblyUtils.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
 
 #include <arm_neon.h>
-#include <cstdlib>
 
 namespace arm_compute
 {
 namespace cpu
 {
+using namespace arm_compute::experimental;
+
 namespace
 {
 struct free_delete
@@ -113,103 +116,27 @@
     return scheduling_hint;
 }
 
-template <typename TypeInput, typename TypeOutput>
-class FallbackTransform : public ITransformWeights
-{
-public:
-    FallbackTransform() noexcept {};
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    FallbackTransform(const FallbackTransform &) = delete;
-    /** Default move constructor */
-    FallbackTransform(FallbackTransform &&) = default;
-    /** Prevent instances of this class from being copied (As this class contains pointers) */
-    FallbackTransform &operator=(const FallbackTransform &) = delete;
-    /** Default move assignment operator */
-    FallbackTransform &operator=(FallbackTransform &&) = default;
-    void               run() override
-    {
-        _output.allocator()->allocate();
-        ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
-        _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
-        _reshape_run = true;
-    }
-
-    void release() override
-    {
-        _output.allocator()->free();
-    }
-
-    ITensor *get_weights() override
-    {
-        return &_output;
-    }
-
-    uint32_t uid() override
-    {
-        uint32_t id = (_B_pretranspose_size | 0x80000000);
-        return id;
-    }
-
-    void configure(size_t B_pretranspose_size, unsigned int alignment)
-    {
-        _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
-        _B_pretranspose_size = B_pretranspose_size;
-    }
-
-    void set_pretranspose(ITensor *tensor)
-    {
-        if(!_reshape_run)
-        {
-            _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
-        }
-    }
-
-    void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
-    {
-        _ldb             = ldb;
-        _in1_ptr         = in1_ptr;
-        _multi_stride_b  = multi_stride_b;
-        _gemm_kernel_asm = gemm_kernel_asm;
-    }
-
-private:
-    Tensor           _output{};
-    int              _ldb{};
-    const TypeInput *_in1_ptr{};
-    int              _multi_stride_b{};
-    size_t           _B_pretranspose_size{};
-    std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
-};
-
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
 class Fallback : public CpuGemmAssemblyDispatch::IFallback
 {
 public:
     /** Destructor */
-    ~Fallback()
-    {
-        if(_pretranspose && !(is_weight_managed()))
-        {
-            delete _pretranspose;
-        }
-    }
+    ~Fallback() = default;
 
     /** Initialise the functions's input and output.
      *
-     * @param[in]  a               Input tensor containing the Matrix A.
-     * @param[in]  b               Input tensor containing the Matrix B.
-     * @param[in]  c               Input tensor containing the Matrix C.
-     * @param[out] d               Output tensor to store the result of matrix multiplication.
-     * @param[in]  args            Matrix multiplication information.
-     * @param[in]  gemm_info       GEMM meta-data
-     * @param[in]  memory_group    Memory group to be used by the function.
-     * @param[in]  weights_manager Weights manager to be used by the function.
-     * @param[in]  os              Output stage meta-data.
+     * @param[in]  a         Input tensor containing the Matrix A.
+     * @param[in]  b         Input tensor containing the Matrix B.
+     * @param[in]  c         Input tensor containing the Matrix C.
+     * @param[out] d         Output tensor to store the result of matrix multiplication.
+     * @param[in]  args      Matrix multiplication information.
+     * @param[in]  gemm_info GEMM meta-data
+     * @param[in]  os        Output stage meta-data.
      */
     void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                    arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+                   const OutputStage &os = {});
 
     /** Set requantization shifts to be used
      *
@@ -231,16 +158,17 @@
     // Inherited methods overridden:
     void run(ITensorPack &tensors) override;
     void prepare(ITensorPack &tensors) override;
-    bool is_configured() const override;
+    bool                             is_configured() const override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
-    /** Allocate a workspace tensor.
-     *
-     * @param[in] workspace_size Size to allocate.
-     * @param[in] memory_group   Tensor memory group.
-     * @param[in] alignment      Workspace memory alignment.
-     */
-    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+    enum AuxTensorIdx
+    {
+        AsmGemmWorkspace = 0,
+        Pretranspose,
+        Count
+    };
+
     /** Configure the indirect buffer
      *
      * @param[in]  a    Input tensor containing the Matrix A.
@@ -256,18 +184,14 @@
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
     /** Optimised Arm® Neon™ kernel */
     std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
-    /** GEMM workspace */
-    Tensor _workspace{};
-    /** Pre-transpose tensor */
-    ITensor *_pretranspose{ nullptr };
+    /** Assembly GEMM workspace tensor info */
+    TensorInfo _workspace_info{};
+    /** Pre-transpose tensor info */
+    TensorInfo _pretranspose_info{};
     /** Prepared flag */
     bool _is_prepared{ false };
     /** GEMM meta-data */
     AsmGemmInfo _gemm_info{};
-    /** Weights manager */
-    IWeightsManager *_weights_manager{ nullptr };
-    /** Weights transform object */
-    FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
     /** GEMM kernel description */
     arm_gemm::KernelDescription _kernel_info{};
     /** Per channel quantization shifts */
@@ -279,27 +203,9 @@
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>          _indirect_pad{};
-    arm_gemm::ConvolutionParameters _cp{};
-
-    bool is_weight_managed()
-    {
-        // TODO (COMPMID-4539): This function should do the following:
-        // _weights_manager && _weights_manager->are_weights_managed(_b)
-        // , where _b is the second Tensor that is used to be given to the configure().
-        // Currently, however, weight manager is disabled to make this class stateless.
-        // This should be revisited in the future.
-        return false;
-    }
-
-    void acquire_managed_weight()
-    {
-        // TODO (COMPMID-4539): This function should do the following:
-        // _pretranspose = _weights_manager->acquire(_b, &_weights_transform);
-        // , where _b is the second Tensor that is used to be given to the configure().
-        // Currently, however, weight manager is disabled to make this class stateless.
-        _pretranspose = nullptr;
-    }
+    std::vector<TypeInput>           _indirect_pad{};
+    arm_gemm::ConvolutionParameters  _cp{};
+    experimental::MemoryRequirements _aux_mem{ Count };
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -439,12 +345,11 @@
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                              arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
+                                                             const OutputStage &os)
 {
     ARM_COMPUTE_UNUSED(c);
     arm_gemm::GemmConfig gemm_cfg;
-    _kernel_info     = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
-    _weights_manager = weights_manager;
+    _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
     if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
     {
         gemm_cfg.filter = _kernel_info.name;
@@ -461,13 +366,10 @@
     auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
     acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
-    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
-    if(workspace_size > 0)
-    {
-        // Allocate workspace
-        const unsigned int alignment = 4096;
-        allocate_workspace(workspace_size, memory_group, alignment);
-    }
+    const size_t       workspace_size = _gemm_kernel_asm->get_working_size();
+    const unsigned int alignment      = 4096;
+    _workspace_info                   = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
+    _aux_mem[AsmGemmWorkspace]        = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);
 
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -487,16 +389,8 @@
         // Forcing 128-byte alignment (required by 32-bit kernels)
         const unsigned int alignment           = 128;
         const size_t       B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
-        if(is_weight_managed())
-        {
-            _weights_transform.configure(B_pretranspose_size, alignment);
-            acquire_managed_weight();
-        }
-        else
-        {
-            _pretranspose = new Tensor();
-            static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
-        }
+        _pretranspose_info                     = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
+        _aux_mem[Pretranspose]                 = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
     }
 
     // Handle indirect GEMM convolution
@@ -509,10 +403,11 @@
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
 {
-    auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
-    auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     if(!_is_prepared)
     {
+        auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+
         // Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
         if(c && c->info()->data_type() == DataType::S32)
         {
@@ -526,24 +421,9 @@
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            if(is_weight_managed())
-            {
-                _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
-                _weights_manager->run(b, &_weights_transform);
-
-                // If we didn't run the reshape function, set the pretransposed buffer
-                if(!_weights_transform.is_reshape_run())
-                {
-                    _weights_transform.set_pretranspose(_pretranspose);
-                }
-            }
-            else
-            {
-                static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
-                ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
-                _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
-                b->mark_as_unused();
-            }
+            CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
+            ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+            _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
         }
 
         if(_gemm_info.method == AsmConvMethod::Indirect)
@@ -556,21 +436,18 @@
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
-{
-    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
-    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
-    memory_group.manage(&_workspace);
-    _workspace.allocator()->allocate();
-}
-
-template <typename TypeInput, typename TypeOutput, class OutputStage>
 bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 {
     return _optimised_kernel != nullptr;
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
+{
+    return _aux_mem;
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
 {
     auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
@@ -609,9 +486,10 @@
     const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
 
     // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
-    if(_workspace.buffer() != nullptr)
+    CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
+    if(workspace.get()->buffer() != nullptr)
     {
-        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
+        _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
         const unsigned int split_dim   = scheduling_hint.split_dimension();
         const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
         unsigned int       num_threads = NEScheduler::get().num_threads();
@@ -656,9 +534,9 @@
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
-                     IWeightsManager *weights_manager)
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+                     const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+                     arm_gemm::Activation activation, const AsmGemmInfo &info)
 {
     Params         p           = extract_parameters(a, b, d, info);
     const CPUInfo &ci          = NEScheduler::get().cpu_info();
@@ -668,14 +546,14 @@
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
+    fallback->configure(a, b, c, d, args, info);
     arm_gemm = std::move(fallback);
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
-                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
-                           IWeightsManager *weights_manager)
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+                           const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+                           arm_gemm::Activation activation, const AsmGemmInfo &info)
 {
     ARM_COMPUTE_UNUSED(activation);
     Params         p           = extract_parameters(a, b, d, info);
@@ -713,14 +591,13 @@
     }
 
     // Configure fallback
-    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
+    fallback->configure(a, b, c, d, args, info, gemm_requant_info);
     arm_gemm = std::move(fallback);
 }
-
 } //namespace
 
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
+    : _arm_gemm(nullptr)
 {
 }
 
@@ -775,40 +652,40 @@
     switch(a->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
             }
             else
             {
-                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
             }
             break;
         case DataType::S8:
         case DataType::QASYMM8_SIGNED:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
             }
             else
             {
-                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
             }
             break;
 #endif /* __aarch64__ */
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
             break;
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
@@ -829,10 +706,14 @@
 
 void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
     _arm_gemm->run(tensors);
 }
+
+experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
+{
+    ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+    return _arm_gemm->workspace();
+}
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c..355273a 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,10 +24,6 @@
 #ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 #define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 
-#include "arm_compute/runtime/IMemoryManager.h"
-#include "arm_compute/runtime/IWeightsManager.h"
-#include "arm_compute/runtime/MemoryGroup.h"
-#include "arm_compute/runtime/Tensor.h"
 #include "src/core/common/Macros.h"
 #include "src/runtime/cpu/ICpuOperator.h"
 
@@ -62,7 +58,7 @@
 {
 public:
     /** Constructor */
-    CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+    CpuGemmAssemblyDispatch();
     /** Defautl destructor */
     ~CpuGemmAssemblyDispatch() = default;
 
@@ -71,10 +67,11 @@
     class IFallback
     {
     public:
-        virtual void run(ITensorPack &tensors)     = 0;
-        virtual void prepare(ITensorPack &tensors) = 0;
-        virtual bool is_configured() const         = 0;
-        virtual ~IFallback()                       = default;
+        virtual void run(ITensorPack &tensors)                         = 0;
+        virtual void prepare(ITensorPack &tensors)                     = 0;
+        virtual experimental::MemoryRequirements workspace() const     = 0;
+        virtual bool                             is_configured() const = 0;
+        virtual ~IFallback()                                           = default;
     };
 
 public:
@@ -115,11 +112,10 @@
     // Inherited methods overridden:
     void prepare(ITensorPack &tensors) override;
     void run(ITensorPack &tensors) override;
+    experimental::MemoryRequirements workspace() const override;
 
 private:
-    std::unique_ptr<IFallback> _arm_gemm;        /**< Interface for the arm_gemm fallback */
-    MemoryGroup                _memory_group;    /**< Function memory group */
-    IWeightsManager           *_weights_manager; /**< Pointer to the weights manager */
+    std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
 };
 } // namespace cpu
 } // namespace arm_compute