Revert "Implement memory injection in CpuDirectGemmConv2d"

This reverts commit b3be45759bdd0749ae3a16fe470820f0d9830ea9.

Resolves: COMPMID-4548

Change-Id: I46e0d8c67ddf988af3ce38f83177cda412db916c
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5775
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
diff --git a/src/core/helpers/MemoryHelpers.h b/src/core/helpers/MemoryHelpers.h
index dfa8e60..6756a90 100644
--- a/src/core/helpers/MemoryHelpers.h
+++ b/src/core/helpers/MemoryHelpers.h
@@ -56,13 +56,12 @@
             continue;
         }
 
-        const auto alignment = req.alignment;
-        const auto aux_info  = TensorInfo{ TensorShape(req.size + alignment), 1, DataType::U8 };
+        const auto aux_info = TensorInfo{ TensorShape(req.size), 1, DataType::U8 };
         workspace_memory.emplace_back(req.slot, std::make_unique<TensorType>());
 
         auto aux_tensor = workspace_memory.back().second.get();
         ARM_COMPUTE_ERROR_ON_NULLPTR(aux_tensor);
-        aux_tensor->allocator()->init(aux_info, alignment);
+        aux_tensor->allocator()->init(aux_info);
 
         if(req.lifetime == experimental::MemoryLifetime::Temporary)
         {
@@ -83,14 +82,5 @@
 
     return workspace_memory;
 }
-
-template <typename TensorType>
-WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
-                                           MemoryGroup                            &mgroup,
-                                           ITensorPack                            &run_pack)
-{
-    ITensorPack dummy_prep_pack{};
-    return manage_workspace<TensorType>(mem_reqs, mgroup, run_pack, dummy_prep_pack);
-}
 } // namespace arm_compute
 #endif /* SRC_COMMON_MEMORY_HELPERS_H */
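The `req.size + alignment` over-allocation removed above is the standard way to honour an alignment request on top of an allocator that gives no alignment guarantee: padding by `alignment` bytes ensures an aligned sub-buffer of the requested size always fits. A minimal illustration of the idea (not library code; valid for power-of-two alignments):

    #include <cstddef>
    #include <cstdint>

    // Round `base` up to the next `alignment`-byte boundary. Combined with an
    // allocation of size + alignment bytes, the aligned pointer is always
    // followed by at least `size` usable bytes.
    inline std::uint8_t *align_up(std::uint8_t *base, std::size_t alignment)
    {
        const auto addr    = reinterpret_cast<std::uintptr_t>(base);
        const auto aligned = (addr + alignment - 1) & ~(std::uintptr_t(alignment) - 1);
        return reinterpret_cast<std::uint8_t *>(aligned);
    }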
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index b526874..7318c3e 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -38,7 +38,6 @@
 #include "src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.h"
 #include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 #include <cmath>
@@ -47,14 +46,6 @@
 
 namespace arm_compute
 {
-using WorkspaceDataType = WorkspaceData<Tensor>;
-
-struct NEGEMM::AsmGlueTensors
-{
-    ITensorPack       tensors{};
-    WorkspaceDataType ws{};
-};
-
 namespace
 {
 cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -72,7 +63,7 @@
 NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
     : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>()), _ma_kernel(),
       _alpha_scale_func(nullptr), _add_bias(), _activation_func(), _tmp_a(), _tmp_b(), _tmp_d(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_alpha_scale(false),
-      _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
+      _run_addition(false), _run_bias_addition(false), _run_activation(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
 {
 }
 
@@ -103,7 +94,7 @@
         _asm_glue->configure(a->info(), b->info(), c_info_to_use, d->info(), asm_info);
         ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());
 
-        _asm_glue_tensors->tensors =
+        _asm_glue_tensors =
         {
             { ACL_SRC_0, a },
             { ACL_SRC_1, b },
@@ -111,8 +102,6 @@
             { ACL_DST, d },
         };
 
-        _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
-
         // Scale product by alpha
         if(_run_alpha_scale)
         {
@@ -334,7 +323,7 @@
 
     if(_asm_glue->is_configured())
     {
-        _asm_glue->run(_asm_glue_tensors->tensors);
+        _asm_glue->run(_asm_glue_tensors);
         if(_run_alpha_scale)
         {
             _alpha_scale_func.run();
@@ -388,7 +377,7 @@
                 ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
             }
 
-            _asm_glue->prepare(_asm_glue_tensors->tensors);
+            _asm_glue->prepare(_asm_glue_tensors);
             if(!original_b_managed_by_weights_manager)
             {
                 _original_b->mark_as_unused();
diff --git a/src/runtime/NEON/functions/NEGEMMConv2d.cpp b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
index 790543a..94ceb6d 100644
--- a/src/runtime/NEON/functions/NEGEMMConv2d.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConv2d.cpp
@@ -26,37 +26,24 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/CpuGemmDirectConv2d.h"
 
 #include <set>
 
 namespace arm_compute
 {
-using OperatorType      = cpu::CpuGemmDirectConv2d;
-using WorkspaceDataType = WorkspaceData<Tensor>;
+using OperatorType = cpu::CpuGemmDirectConv2d;
 
 struct NEGEMMConv2d::Impl
 {
     ITensorPack                   tensors{};
-    MemoryGroup                   mg{};
     std::unique_ptr<OperatorType> op{ nullptr };
-    WorkspaceDataType             ws{};
-
-    void allocate_and_add_workspace()
-    {
-        if(op)
-        {
-            ws = manage_workspace<Tensor>(op->workspace(), mg, tensors);
-        }
-    }
 };
 
 NEGEMMConv2d::NEGEMMConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
     : _impl(std::make_unique<Impl>())
 {
-    _impl->op = std::make_unique<OperatorType>();
-    _impl->mg = MemoryGroup(memory_manager);
+    _impl->op = std::make_unique<OperatorType>(memory_manager);
 }
 
 NEGEMMConv2d::~NEGEMMConv2d() = default;
@@ -68,9 +55,7 @@
     _impl->tensors.add_const_tensor(TensorType::ACL_SRC_2, biases);
     _impl->tensors.add_tensor(TensorType::ACL_DST, output);
 
-    _impl->op->configure(input->info(), weights->info(), ((biases) ? biases->info() : nullptr), output->info(), info);
-
-    _impl->allocate_and_add_workspace();
+    _impl->op->configure(input->info(), weights->info(), biases->info(), output->info(), info);
 }
 
 Status NEGEMMConv2d::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const Conv2dInfo &info)
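With the workspace plumbing reverted, NEGEMMConv2d simply forwards its memory manager to the operator. A sketch of the usual Compute Library set-up for that (a hedged example; class names are from the public runtime API, configuration details omitted):

    #include <memory>

    #include "arm_compute/runtime/BlobLifetimeManager.h"
    #include "arm_compute/runtime/MemoryManagerOnDemand.h"
    #include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
    #include "arm_compute/runtime/PoolManager.h"

    using namespace arm_compute;

    void create_function()
    {
        // Standard on-demand manager: a lifetime manager plus a pool manager.
        auto lifetime_mgr = std::make_shared<BlobLifetimeManager>();
        auto pool_mgr     = std::make_shared<PoolManager>();
        auto memory_mgr   = std::make_shared<MemoryManagerOnDemand>(lifetime_mgr, pool_mgr);

        // After this revert, the manager reaches CpuGemmAssemblyDispatch via
        // CpuGemmDirectConv2d, which uses it for the assembly workspace tensor.
        NEGEMMConv2d conv{ memory_mgr };
        // conv.configure(input, weights, biases, output, info); conv.run();
    }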
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index d42e656..cc0f20e 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -42,17 +42,10 @@
 #include "src/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
 #include "src/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
 #include "src/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
-#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
 
 namespace arm_compute
 {
-using WorkspaceDataType = WorkspaceData<Tensor>;
-struct NEGEMMLowpMatrixMultiplyCore::AsmGlueTensors
-{
-    ITensorPack       tensors{};
-    WorkspaceDataType ws{};
-};
 namespace
 {
 cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
@@ -73,11 +66,11 @@
 NEGEMMLowpMatrixMultiplyCore::~NEGEMMLowpMatrixMultiplyCore() = default;
 
 NEGEMMLowpMatrixMultiplyCore::NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
-    : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
+    : _memory_group(memory_manager), _weights_manager(weights_manager), _asm_glue(std::make_unique<cpu::CpuGemmAssemblyDispatch>(memory_manager, weights_manager)), _mm_kernel(), _mtx_a_reshape_kernel(),
       _mtx_b_reshape_kernel(), _mtx_a_reduction_kernel(), _mtx_b_reduction_kernel(), _offset_contribution_kernel(), _offset_contribution_output_stage_kernel(), _activation_func(),
       _convert_to_signed_asymm(), _convert_from_signed_asymm(), _vector_sum_col(), _vector_sum_row(), _tmp_a(), _tmp_b(), _mm_result_s32(), _signed_a(), _signed_output(), _original_b(nullptr), _a_offset(0),
       _b_offset(0), _run_vector_matrix_multiplication(false), _assembly_path(false), _fused_assembly_path(false), _reshape_b_only_on_first_run(false), _is_prepared(false), _fuse_output_stage(false),
-      _run_activation(false), _flip_signedness(false), _asm_glue_tensors(std::make_unique<AsmGlueTensors>())
+      _run_activation(false), _flip_signedness(false)
 {
 }
 
@@ -156,24 +149,18 @@
                 auto c_info_to_use = c == nullptr ? nullptr : c->info();
                 _asm_glue->configure(a_to_use->info(), b->info(), c_info_to_use, output->info(), asm_info);
                 _fused_assembly_path = _asm_glue->is_configured();
-                _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
-                _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output);
+                _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_2, c);
+                _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output);
             }
             else
             {
                 auto output_to_use = (_fuse_output_stage ? &_mm_result_s32 : output);
                 _asm_glue->configure(a_to_use->info(), b->info(), nullptr, output_to_use->info(), asm_info);
-                _asm_glue_tensors->tensors.add_tensor(TensorType::ACL_DST, output_to_use);
+                _asm_glue_tensors.add_tensor(TensorType::ACL_DST, output_to_use);
             }
             _assembly_path = _asm_glue->is_configured();
-            _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
-            _asm_glue_tensors->tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
-
-            if(_assembly_path)
-            {
-                _asm_glue_tensors->ws = manage_workspace<Tensor>(_asm_glue->workspace(), _memory_group, _asm_glue_tensors->tensors);
-            }
-
+            _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_0, a_to_use);
+            _asm_glue_tensors.add_const_tensor(TensorType::ACL_SRC_1, b);
             break;
         }
         default:
@@ -533,7 +520,7 @@
     // Run GEMM
     if(_asm_glue->is_configured())
     {
-        _asm_glue->run(_asm_glue_tensors->tensors);
+        _asm_glue->run(_asm_glue_tensors);
     }
     else
     {
@@ -603,7 +590,7 @@
                 ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
             }
 
-            _asm_glue->prepare(_asm_glue_tensors->tensors);
+            _asm_glue->prepare(_asm_glue_tensors);
             if(!original_b_managed_by_weights_manager)
             {
                 _original_b->mark_as_unused();
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index 7b7b68a..e50099d 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -53,13 +53,11 @@
                                                                                ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                                ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                              };
-
     PixelValue type_min{};
     PixelValue type_max{};
-
     std::tie(type_min, type_max) = get_min_max(data_type);
-    int32_t min_activation       = type_min.get<int32_t>();
-    int32_t max_activation       = type_max.get<int32_t>();
+    int32_t min_activation = type_min.get<int32_t>();
+    int32_t max_activation = type_max.get<int32_t>();
     if(supported_acts.count(act.activation()) != 0)
     {
         std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo);
@@ -89,8 +87,8 @@
 }
 } // namespace
 
-CpuGemmDirectConv2d::CpuGemmDirectConv2d()
-    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
+    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
       _activation_func(std::make_unique<CpuActivation>()),
       _weights_permute_func(std::make_unique<CpuPermute>()),
       _permuted_weights_info(),
@@ -165,8 +163,6 @@
 }
 void CpuGemmDirectConv2d::run(ITensorPack &tensors)
 {
-    import_workspace_memory(tensors);
-
     prepare(tensors);
 
     _gemm_asm_func->run(tensors);
@@ -174,14 +170,22 @@
     {
         _activation_func->run(tensors);
     }
+}
 
-    free_imported_workspace_memory();
+void CpuGemmDirectConv2d::allocate_permuted_weights()
+{
+    // TODO: This function will be removed when memory injection is implemented.
+    ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
+    _permuted_weights->allocator()->free();
+    _permuted_weights->allocator()->init(_permuted_weights_info);
+    _permuted_weights->allocator()->allocate();
 }
 
 void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
 {
     if(!_is_prepared)
     {
+        allocate_permuted_weights();
         ITensorPack permute_tensors
         {
             { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
@@ -198,41 +202,5 @@
     }
 }
 
-experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
-{
-    experimental::MemoryRequirements req = _gemm_asm_func->workspace();
-
-    auto index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_0);
-
-    if(req.size() > 0)
-    {
-        index = req.back().slot + 1;
-
-        constexpr auto max_index = static_cast<std::underlying_type<TensorType>::type>(TensorType::ACL_INT_4);
-        ARM_COMPUTE_UNUSED(max_index); // in order to prevent build error with assertion is disabled.
-        ARM_COMPUTE_ERROR_ON(index > max_index);
-    }
-
-    req.emplace_back(index, _permuted_weights_info.total_size(), 0);
-
-    return req;
-}
-
-void CpuGemmDirectConv2d::import_workspace_memory(ITensorPack &tensors)
-{
-    auto imported_tensor = tensors.get_tensor(workspace().back().slot);
-
-    ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
-
-    auto imported_memory = imported_tensor->buffer();
-    _permuted_weights->allocator()->init(_permuted_weights_info);
-    _permuted_weights->allocator()->import_memory(imported_memory);
-}
-
-void CpuGemmDirectConv2d::free_imported_workspace_memory()
-{
-    _permuted_weights->allocator()->free();
-}
-
 } // namespace cpu
 } // namespace arm_compute
\ No newline at end of file
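The reintroduced `allocate_permuted_weights()` relies on the free/init/allocate sequence the library uses to (re)size a member tensor. The same pattern in isolation, as a minimal sketch against the public `Tensor` API:

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    void resize_tensor(Tensor &t, const TensorInfo &new_info)
    {
        t.allocator()->free();          // release any previously owned buffer
        t.allocator()->init(new_info);  // adopt the new shape/type metadata
        t.allocator()->allocate();      // acquire memory for the new size
    }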
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 305a076..6aa17c2 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -48,7 +48,7 @@
 {
 public:
     /** Constructor */
-    CpuGemmDirectConv2d();
+    CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
     /** Destructor */
     ~CpuGemmDirectConv2d();
@@ -80,16 +80,15 @@
     void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv2dInfo &info);
     /** Static function to check if given info will lead to a valid configuration of @ref CpuGemmDirectConv2d
      *
-     * Similar to @ref CpuGemmDirectConv2d::configure()
+     * Similar to CpuGemmDirectConv2d::configure()
      *
      * @return a status
      */
     static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info);
 
     // Inherited methods overridden:
-    void                             run(ITensorPack &tensors) override;
-    void                             prepare(ITensorPack &constants) override;
-    experimental::MemoryRequirements workspace() const override;
+    void run(ITensorPack &tensors) override;
+    void prepare(ITensorPack &constants) override;
 
 private:
     std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
@@ -101,13 +100,11 @@
     bool                                     _is_prepared{ false };
     bool                                     _run_activation{ false };
 
-    /** Function to import workspace tensors
+    /** Function to allocate a tensor for permuted weights
      *
-     * @param[in] tensors Tensor pack includes workspace tensors
+     * @note This function will be removed when memory injection is properly implemented.
      */
-    void import_workspace_memory(ITensorPack &tensors);
-    /** Function free used workspace tensors */
-    void free_imported_workspace_memory();
+    void allocate_permuted_weights();
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 53d71a3..ea3742f 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -204,11 +204,11 @@
     }
 
 private:
-    Tensor                                                       _output{};
-    int                                                          _ldb{};
-    const TypeInput                                             *_in1_ptr{};
-    int                                                          _multi_stride_b{};
-    size_t                                                       _B_pretranspose_size{};
+    Tensor           _output{};
+    int              _ldb{};
+    const TypeInput *_in1_ptr{};
+    int              _multi_stride_b{};
+    size_t           _B_pretranspose_size{};
     std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
 };
 
@@ -240,7 +240,7 @@
      */
     void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                    arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                   IWeightsManager *weights_manager, const OutputStage &os = {});
+                   MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
 
     /** Set requantization shifts to be used
      *
@@ -265,42 +265,13 @@
     bool is_configured() const override;
 
 private:
-    static constexpr size_t _workspace_alignment{ 4096 };
-    /** Function to get the memory requirements */
-    experimental::MemoryRequirements get_workspace() const override
-    {
-        experimental::MemoryRequirements req{};
-        const auto                       size = _gemm_kernel_asm->get_working_size();
-        if(size > 0)
-        {
-            req.emplace_back(TensorType::ACL_INT, size, _workspace_alignment);
-        }
-        return req;
-    }
-
-    /** Function to import workspace tensors
+    /** Allocate a workspace tensor.
      *
-     * @param[in] tensors Tensor pack includes workspace tensors
+     * @param[in] workspace_size Size to allocate.
+     * @param[in] memory_group   Tensor memory group.
+     * @param[in] alignment      Workspace memory alignment.
      */
-    void import_workspace(ITensorPack &tensors)
-    {
-        const auto size = _gemm_kernel_asm->get_working_size();
-
-        if(size > 0)
-        {
-            auto imported_tensor = tensors.get_tensor(TensorType::ACL_INT);
-            ARM_COMPUTE_ERROR_ON_NULLPTR(imported_tensor);
-            const size_t workspace_size = _gemm_kernel_asm->get_working_size();
-            _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + _workspace_alignment) }, 1, DataType::S8), _workspace_alignment);
-            _workspace.allocator()->import_memory(imported_tensor->buffer());
-        }
-    }
-    /** Function free used workspace tensors */
-    void free_imported_workspace()
-    {
-        _workspace.allocator()->free();
-    }
-
+    void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
     /** Configure the indirect buffer
      *
      * @param[in]  a    Input tensor containing the Matrix A.
@@ -339,8 +310,8 @@
     /** Indirect buffer */
     std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
     std::unique_ptr<const TypeInput *, free_delete>        _indirect_buf{};
-    std::vector<TypeInput>                                 _indirect_pad{};
-    arm_gemm::ConvolutionParameters                        _cp{};
+    std::vector<TypeInput>          _indirect_pad{};
+    arm_gemm::ConvolutionParameters _cp{};
 
     bool is_weight_managed()
     {
@@ -363,9 +334,6 @@
 };
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
-constexpr size_t Fallback<TypeInput, TypeOutput, OutputStage>::_workspace_alignment;
-
-template <typename TypeInput, typename TypeOutput, class OutputStage>
 std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
 Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
 {
@@ -502,7 +470,7 @@
 template <typename TypeInput, typename TypeOutput, class OutputStage>
 void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
                                                              arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
-                                                             IWeightsManager *weights_manager, const OutputStage &os)
+                                                             MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
 {
     ARM_COMPUTE_UNUSED(c);
     arm_gemm::GemmConfig gemm_cfg;
@@ -524,6 +492,13 @@
     auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
     ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
     acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
+    const size_t workspace_size = _gemm_kernel_asm->get_working_size();
+    if(workspace_size > 0)
+    {
+        // Allocate workspace
+        const unsigned int alignment = 4096;
+        allocate_workspace(workspace_size, memory_group, alignment);
+    }
 
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -612,6 +587,15 @@
 }
 
 template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
+    _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
+    memory_group.manage(&_workspace);
+    _workspace.allocator()->allocate();
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
 bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
 {
     return _optimised_kernel != nullptr;
@@ -625,10 +609,6 @@
     auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
     auto d = tensors.get_tensor(TensorType::ACL_DST);
 
-    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-
-    import_workspace(tensors);
-
     int       lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
     int       ldb = 0;
     const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
@@ -704,11 +684,10 @@
                                  bias, 0);
     // Schedule
     NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
-    free_imported_workspace();
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
                      const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                      IWeightsManager *weights_manager)
 {
@@ -720,12 +699,12 @@
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
-    fallback->configure(a, b, c, d, args, info, weights_manager);
+    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
     arm_gemm = std::move(fallback);
 }
 
 template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
                            const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
                            IWeightsManager *weights_manager)
 {
@@ -765,14 +744,14 @@
     }
 
     // Configure fallback
-    fallback->configure(a, b, c, d, args, info, weights_manager, gemm_requant_info);
+    fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
     arm_gemm = std::move(fallback);
 }
 
 } //namespace
 
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(IWeightsManager *weights_manager)
-    : _arm_gemm(nullptr), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+    : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
 {
 }
 
@@ -827,40 +806,40 @@
     switch(a->data_type())
     {
         case DataType::F32:
-            create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             break;
         case DataType::S8:
         case DataType::QASYMM8_SIGNED:
             if(d->data_type() == DataType::S32)
             {
-                create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             else
             {
-                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+                create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             }
             break;
 #endif /* __aarch64__ */
 #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
         case DataType::BFLOAT16:
-            create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info, _weights_manager);
+            create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default:
@@ -881,13 +860,10 @@
 
 void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
 {
+    MemoryGroupResourceScope scope_mg(_memory_group);
+
     ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
     _arm_gemm->run(tensors);
 }
-
-experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
-{
-    return is_configured() ? _arm_gemm->get_workspace() : experimental::MemoryRequirements{};
-}
 } // namespace cpu
 } // namespace arm_compute
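Restoring `allocate_workspace()` together with the `MemoryGroupResourceScope` in `run()` brings back the classic MemoryGroup lifecycle: register the workspace while configuring, then acquire the group's backing memory only for the duration of each run. A simplified sketch of that lifecycle (function and parameter names are illustrative):

    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/runtime/MemoryGroup.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    void at_configure_time(MemoryGroup &mg, Tensor &workspace, const TensorInfo &info)
    {
        workspace.allocator()->init(info);
        mg.manage(&workspace);              // register while memory is not yet needed
        workspace.allocator()->allocate();  // commit; backing memory comes from the group
    }

    void at_run_time(MemoryGroup &mg)
    {
        // Acquires the group's memory on construction and releases it on scope
        // exit, so the workspace is resident only while kernels execute.
        MemoryGroupResourceScope scope_mg(mg);
        // ... schedule the assembly kernel that reads/writes the workspace ...
    }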
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 154def6..ffc097c 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,6 +24,7 @@
 #ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 #define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
 
+#include "arm_compute/runtime/IMemoryManager.h"
 #include "arm_compute/runtime/IWeightsManager.h"
 #include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/Tensor.h"
@@ -61,7 +62,7 @@
 {
 public:
     /** Constructor */
-    CpuGemmAssemblyDispatch(IWeightsManager *weights_manager = nullptr);
+    CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
     /** Default destructor */
     ~CpuGemmAssemblyDispatch() = default;
 
@@ -70,11 +71,10 @@
     class IFallback
     {
     public:
-        virtual void run(ITensorPack &tensors)                         = 0;
-        virtual void prepare(ITensorPack &tensors)                     = 0;
-        virtual bool is_configured() const                             = 0;
-        virtual ~IFallback()                                           = default;
-        virtual experimental::MemoryRequirements get_workspace() const = 0;
+        virtual void run(ITensorPack &tensors)     = 0;
+        virtual void prepare(ITensorPack &tensors) = 0;
+        virtual bool is_configured() const         = 0;
+        virtual ~IFallback()                       = default;
     };
 
 public:
@@ -113,12 +113,12 @@
     bool is_configured() const;
 
     // Inherited methods overridden:
-    void                             prepare(ITensorPack &tensors) override;
-    void                             run(ITensorPack &tensors) override;
-    experimental::MemoryRequirements workspace() const override;
+    void prepare(ITensorPack &tensors) override;
+    void run(ITensorPack &tensors) override;
 
 private:
     std::unique_ptr<IFallback> _arm_gemm;        /**< Interface for the arm_gemm fallback */
+    MemoryGroup                _memory_group;    /**< Function memory group */
     IWeightsManager           *_weights_manager; /**< Pointer to the weights manager */
 };
 } // namespace cpu