Fix incorrect memory handling in ported functions

Affected functions:
- ClSoftmax
- CpuSoftmax
- CpuPool2d

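The fix makes all three follow one pattern: configure() queries the
operator's workspace() requirements once, manage_workspace() (in
src/core/helpers/MemoryHelpers.h) allocates the auxiliary tensors,
registers them with the function's MemoryGroup and injects them into a
run pack that the function caches, and run() only opens a
MemoryGroupResourceScope and forwards the cached pack. Inside the
operators, caller-provided workspace buffers are picked up through
CpuAuxTensorHandler (added here) and CLAuxTensorHandler. A minimal,
self-contained C++ sketch of that configure/run split follows; the
types are simplified stand-ins with no MemoryGroup or lifetime
handling, not Compute Library classes, and import_or_allocate() is a
hypothetical name for what the aux tensor handlers do:

    // Illustrative stand-ins only -- NOT Compute Library types.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <utility>
    #include <vector>

    struct MemoryInfo // cf. experimental::MemoryInfo
    {
        int         slot;
        std::size_t size;
    };
    using MemoryRequirements = std::vector<MemoryInfo>;

    struct Tensor // stand-in for Tensor / CLTensor
    {
        std::vector<std::uint8_t> buffer;
    };
    using TensorPack    = std::map<int, Tensor *>; // stand-in for ITensorPack
    using WorkspaceData = std::vector<std::pair<int, std::unique_ptr<Tensor>>>;

    // cf. manage_workspace() in MemoryHelpers.h: allocate one tensor per
    // requirement and inject it into the pack cached by the function.
    WorkspaceData manage_workspace(const MemoryRequirements &reqs, TensorPack &run_pack)
    {
        WorkspaceData workspace;
        for(const auto &req : reqs)
        {
            if(req.size == 0)
            {
                continue; // unused slot, e.g. no permutation needed
            }
            auto tensor = std::make_unique<Tensor>();
            tensor->buffer.resize(req.size);
            run_pack[req.slot] = tensor.get();
            workspace.emplace_back(req.slot, std::move(tensor));
        }
        return workspace;
    }

    // cf. CpuAuxTensorHandler / CLAuxTensorHandler, reduced to a hypothetical
    // free function: use the packed buffer when it is present and large
    // enough, otherwise fall back to a locally owned allocation.
    Tensor *import_or_allocate(int slot, std::size_t needed, TensorPack &pack,
                               std::unique_ptr<Tensor> &fallback)
    {
        const auto it = pack.find(slot);
        if(it != pack.end() && it->second->buffer.size() >= needed)
        {
            return it->second; // import caller-provided workspace
        }
        fallback = std::make_unique<Tensor>();
        fallback->buffer.resize(needed);
        return fallback.get();
    }

    int main()
    {
        // configure(): build the run pack once and bind the workspace to it.
        TensorPack          run_pack{};
        const WorkspaceData workspace = manage_workspace({ { -1, 256 }, { -2, 0 } }, run_pack);

        // run(): no re-allocation and no re-packing; the operator imports
        // what it needs from the pack built at configure time.
        std::unique_ptr<Tensor> fallback{};
        Tensor                 *tmp = import_or_allocate(-1, 128, run_pack, fallback);
        std::cout << "workspace tensors: " << workspace.size()
                  << ", pack slots: " << run_pack.size()
                  << ", tmp bytes: " << tmp->buffer.size() << "\n";
        return 0;
    }

At run time the wrappers therefore do no per-call work beyond acquiring
the memory group scope; the pack built at configure time is reused as-is.
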
Change-Id: Icd2c14d5df010c3b2301e2693ce6f414d7c61916
Resolves: COMPMID-4404
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5797
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/helpers/MemoryHelpers.h b/src/core/helpers/MemoryHelpers.h
index 6756a90..e751e60 100644
--- a/src/core/helpers/MemoryHelpers.h
+++ b/src/core/helpers/MemoryHelpers.h
@@ -46,6 +46,15 @@
 template <typename TensorType>
 WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
                                            MemoryGroup                            &mgroup,
+                                           ITensorPack                            &run_pack)
+{
+    ITensorPack dummy_pack = ITensorPack();
+    return manage_workspace<TensorType>(mem_reqs, mgroup, run_pack, dummy_pack);
+}
+
+template <typename TensorType>
+WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
+                                           MemoryGroup                            &mgroup,
                                            ITensorPack &run_pack, ITensorPack &prep_pack)
 {
     WorkspaceData<TensorType> workspace_memory;
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index fe45f65..de58bf1 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/gpu/cl/operators/ClPermute.h"
 #include "src/runtime/gpu/cl/operators/ClSoftmax.h"
 
@@ -43,7 +44,8 @@
     ICLTensor                    *dst{ nullptr };
     std::unique_ptr<OperatorType> op{ nullptr };
     MemoryGroup                   memory_group{};
-    std::vector<std::pair<int, std::unique_ptr<CLTensor>>> workspace_tensors{};
+    ITensorPack                   run_pack{};
+    WorkspaceData<CLTensor>       workspace_tensors{};
 };
 
 template <bool IS_LOG>
@@ -71,7 +73,9 @@
 
     SoftmaxKernelInfo softmax_info{ beta, IS_LOG, input->info()->data_type(), axis };
     _impl->op->configure(compile_context, *input->info(), *output->info(), softmax_info);
-    allocate_workspace();
+
+    _impl->run_pack          = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST, _impl->dst } };
+    _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
 }
 
 template <bool IS_LOG>
@@ -82,46 +86,12 @@
 }
 
 template <bool IS_LOG>
-void           CLSoftmaxLayerGeneric<IS_LOG>::allocate_workspace()
-{
-    const auto memory_requirements = _impl->op->workspace();
-    std::for_each(memory_requirements.begin(), memory_requirements.end(), [this](const experimental::MemoryInfo & memory_info)
-    {
-        auto tensor_info = TensorInfo{ TensorShape(memory_info.size), 1, DataType::U8 };
-        _impl->workspace_tensors.emplace_back(memory_info.slot, std::make_unique<CLTensor>());
-        auto tensor = _impl->workspace_tensors.back().second.get();
-        ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
-        tensor->allocator()->init(tensor_info);
-        _impl->memory_group.manage(tensor);
-    });
-
-    std::for_each(_impl->workspace_tensors.begin(), _impl->workspace_tensors.end(), [](std::pair<int, std::unique_ptr<CLTensor>> &wt)
-    {
-        auto tensor = wt.second.get();
-        tensor->allocator()->allocate();
-    });
-}
-
-template <bool IS_LOG>
 void           CLSoftmaxLayerGeneric<IS_LOG>::run()
 {
     // Acquire all the temporaries
     MemoryGroupResourceScope scope_mg(_impl->memory_group);
-
     ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
-
-    ITensorPack pack;
-    pack.add_tensor(TensorType::ACL_SRC, _impl->src);
-    pack.add_tensor(TensorType::ACL_DST, _impl->dst);
-
-    std::for_each(_impl->workspace_tensors.begin(), _impl->workspace_tensors.end(), [&pack](std::pair<int, std::unique_ptr<CLTensor>> &wt)
-    {
-        auto tensor = wt.second.get();
-        ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
-        pack.add_tensor(wt.first, tensor);
-    });
-
-    _impl->op->run(pack);
+    _impl->op->run(_impl->run_pack);
 }
 
 template class CLSoftmaxLayerGeneric<false>;
diff --git a/src/runtime/NEON/functions/NEPoolingLayer.cpp b/src/runtime/NEON/functions/NEPoolingLayer.cpp
index bbf3e7c..8d267a3 100644
--- a/src/runtime/NEON/functions/NEPoolingLayer.cpp
+++ b/src/runtime/NEON/functions/NEPoolingLayer.cpp
@@ -26,6 +26,7 @@
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/runtime/Tensor.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/runtime/cpu/operators/CpuPool2d.h"
 
 namespace arm_compute
@@ -35,15 +36,18 @@
     ITensor                        *src{ nullptr };
     ITensor                        *dst{ nullptr };
     ITensor                        *indices{ nullptr };
-    Tensor                          workspace{ nullptr };
     std::unique_ptr<cpu::CpuPool2d> op{ nullptr };
+    MemoryGroup                     memory_group{};
+    ITensorPack                     run_pack{};
+    WorkspaceData<Tensor>           workspace_tensors{};
 };
 
 NEPoolingLayer::~NEPoolingLayer() = default;
 
 NEPoolingLayer::NEPoolingLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _impl(std::make_unique<Impl>())
+    : _impl(std::make_unique<Impl>())
 {
+    _impl->memory_group = MemoryGroup(std::move(memory_manager));
 }
 
 void NEPoolingLayer::configure(ITensor *input, ITensor *output, const PoolingLayerInfo &pool_info, ITensor *indices)
@@ -54,14 +58,8 @@
     _impl->op      = std::make_unique<cpu::CpuPool2d>();
     _impl->op->configure(input->info(), output->info(), pool_info, (indices) ? indices->info() : nullptr);
 
-    // Allocate workspace based on kernel's memory requirements
-    const experimental::MemoryRequirements mem_req = _impl->op->workspace();
-    if(!mem_req.empty())
-    {
-        _impl->workspace.allocator()->init(TensorInfo(TensorShape{ (mem_req[0].size + mem_req[0].alignment) }, 1, DataType::S8), mem_req[0].alignment);
-        _memory_group.manage(&_impl->workspace);
-        _impl->workspace.allocator()->allocate();
-    }
+    _impl->run_pack          = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST_0, _impl->dst }, { TensorType::ACL_DST_1, _impl->indices } };
+    _impl->workspace_tensors = manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
 }
 
 Status NEPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, const ITensorInfo *indices)
@@ -71,11 +69,8 @@
 
 void NEPoolingLayer::run()
 {
-    ITensorPack pack;
-    pack.add_tensor(TensorType::ACL_SRC, _impl->src);
-    pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
-    pack.add_tensor(TensorType::ACL_DST_1, _impl->indices);
-    pack.add_tensor(TensorType::ACL_INT_0, &_impl->workspace);
-    _impl->op->run(pack);
+    MemoryGroupResourceScope scope_mg(_impl->memory_group);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
+    _impl->op->run(_impl->run_pack);
 }
 } // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
index 3f1e43a..af8546d 100644
--- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp
+++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
@@ -23,6 +23,7 @@
  */
 #include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
 #include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/MemoryGroup.h"
 #include "arm_compute/runtime/Tensor.h"
 #include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
 #include "src/core/helpers/SoftmaxHelpers.h"
@@ -36,16 +37,17 @@
     const ITensor                                  *src{ nullptr };
     ITensor                                        *dst{ nullptr };
     Tensor                                          max{ nullptr };
-    Tensor                                          tmp{ nullptr };
-    Tensor                                          input_permuted{ nullptr };
-    Tensor                                          output_permuted{ nullptr };
     std::unique_ptr<cpu::CpuSoftmaxGeneric<IS_LOG>> op{ nullptr };
+    MemoryGroup                                     memory_group{};
+    ITensorPack                                     run_pack{};
+    WorkspaceData<Tensor>                           workspace_tensors{};
 };
 
 template <bool IS_LOG>
 NESoftmaxLayerGeneric<IS_LOG>::NESoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(std::move(memory_manager)), _impl(std::make_unique<Impl>())
+    : _impl(std::make_unique<Impl>())
 {
+    _impl->memory_group = MemoryGroup(std::move(memory_manager));
 }
 
 template <bool IS_LOG>
@@ -65,64 +67,8 @@
     _impl->op  = std::make_unique<cpu::CpuSoftmaxGeneric<IS_LOG>>();
     _impl->op->configure(input->info(), output->info(), beta, axis);
 
-    const unsigned int actual_axis   = static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(input->info()->num_dimensions())));
-    const bool         needs_permute = actual_axis > 0;
-    if(needs_permute)
-    {
-        // Add to the memory manager _input_permuted
-        auto permute_input = std::make_unique<cpu::CpuPermute>();
-        _memory_group.manage(&_impl->input_permuted);
-        permute_input->configure(input->info(), _impl->input_permuted.info(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
-    }
-
-    // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
-    // or it is the original input case (2D case)
-    ITensor *tmp_input = (needs_permute ? &_impl->input_permuted : input);
-
-    // Create intermediate tensors shapes
-    const TensorInfo input_info    = tmp_input->info()->clone()->reset_padding().set_is_resizable(true);
-    DataType         tmp_data_type = is_data_type_quantized_asymmetric(tmp_input->info()->data_type()) ? DataType::F32 : tmp_input->info()->data_type();
-    TensorInfo       tensor_info_tmp(input_info.clone()->set_data_type(tmp_data_type));
-
-    // Init intermediate tensors
-    TensorShape max_sum_shape = tmp_input->info()->tensor_shape();
-    max_sum_shape.set(0, 1);
-    _impl->max.allocator()->init(input_info.clone()->set_tensor_shape(max_sum_shape));
-    _impl->tmp.allocator()->init(tensor_info_tmp);
-
-    // Manage intermediate buffers
-    _memory_group.manage(&_impl->max);
-    _memory_group.manage(&_impl->tmp);
-
-    // Configure kernels
-    auto max_kernel     = std::make_unique<cpu::kernels::CpuLogits1DMaxKernel>();
-    auto softmax_kernel = std::make_unique<cpu::kernels::CpuLogits1DSoftmaxKernel<IS_LOG>>();
-    max_kernel->configure(tmp_input->info(), _impl->max.info());
-
-    if(needs_permute)
-    {
-        auto permute_output = std::make_unique<cpu::CpuPermute>();
-        // Add to the memory manager _output_permuted
-        _memory_group.manage(&_impl->output_permuted);
-
-        // The normalization kernel stores the result in a permuted output tensor
-        softmax_kernel->configure(tmp_input->info(), _impl->max.info(), _impl->output_permuted.info(), beta, _impl->tmp.info());
-        _impl->input_permuted.allocator()->allocate();
-
-        // Re-permute the permuted output into the requested (4D) output
-        permute_output->configure(_impl->output_permuted.info(), output->info(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
-
-        // Allocate the intermediate permuted tensors
-        _impl->output_permuted.allocator()->allocate();
-    }
-    else
-    {
-        softmax_kernel->configure(tmp_input->info(), _impl->max.info(), output->info(), beta, _impl->tmp.info());
-    }
-
-    // Allocate intermediate buffers
-    _impl->max.allocator()->allocate();
-    _impl->tmp.allocator()->allocate();
+    _impl->run_pack          = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST, _impl->dst } };
+    _impl->workspace_tensors = manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
 }
 
 template <bool IS_LOG>
@@ -136,15 +82,10 @@
 template <bool IS_LOG>
 void           NESoftmaxLayerGeneric<IS_LOG>::run()
 {
-    MemoryGroupResourceScope scope_mg(_memory_group);
-    ITensorPack              pack;
-    pack.add_tensor(TensorType::ACL_SRC, _impl->src);
-    pack.add_tensor(TensorType::ACL_DST, _impl->dst);
-    pack.add_tensor(TensorType::ACL_INT_0, &_impl->tmp);
-    pack.add_tensor(TensorType::ACL_INT_1, &_impl->max);
-    pack.add_tensor(TensorType::ACL_INT_2, &_impl->input_permuted);
-    pack.add_tensor(TensorType::ACL_INT_3, &_impl->output_permuted);
-    _impl->op->run(pack);
+    // Acquire all the temporaries
+    MemoryGroupResourceScope scope_mg(_impl->memory_group);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
+    _impl->op->run(_impl->run_pack);
 }
 
 template class NESoftmaxLayerGeneric<false>;
diff --git a/src/runtime/cpu/operators/CpuPool2d.cpp b/src/runtime/cpu/operators/CpuPool2d.cpp
index b225199..e746c8f 100644
--- a/src/runtime/cpu/operators/CpuPool2d.cpp
+++ b/src/runtime/cpu/operators/CpuPool2d.cpp
@@ -30,6 +30,8 @@
 #include "src/core/cpu/kernels/CpuPool2dKernel.h"
 #include "src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h"
 
+using namespace arm_compute::experimental;
+
 namespace arm_compute
 {
 namespace cpu
@@ -40,7 +42,7 @@
       _asm_glue(),
       _is_global_pooling_layer(false),
       _data_layout(DataLayout::NCHW),
-      _mem_req()
+      _aux_mem(1)
 {
 }
 
@@ -71,7 +73,7 @@
         // Get kernel's memory requirements
         constexpr size_t alignment      = 4096;
         const size_t     workspace_size = pooling_wrapper->get_working_size(num_threads);
-        _mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment });
+        _aux_mem[0]                     = MemoryInfo(TensorType::ACL_INT_0, MemoryLifetime::Temporary, workspace_size, alignment);
 
         _asm_glue = std::move(pooling_wrapper);
     }
@@ -150,7 +152,7 @@
 
 experimental::MemoryRequirements CpuPool2d::workspace() const
 {
-    return _mem_req;
+    return _aux_mem;
 }
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuPool2d.h b/src/runtime/cpu/operators/CpuPool2d.h
index ae3d115..68416b5 100644
--- a/src/runtime/cpu/operators/CpuPool2d.h
+++ b/src/runtime/cpu/operators/CpuPool2d.h
@@ -80,7 +80,7 @@
 
     bool                             _is_global_pooling_layer;
     DataLayout                       _data_layout;
-    experimental::MemoryRequirements _mem_req;
+    experimental::MemoryRequirements _aux_mem{};
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuSoftmax.cpp b/src/runtime/cpu/operators/CpuSoftmax.cpp
index 0e1bcd5..e17925e 100644
--- a/src/runtime/cpu/operators/CpuSoftmax.cpp
+++ b/src/runtime/cpu/operators/CpuSoftmax.cpp
@@ -29,7 +29,11 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/core/helpers/SoftmaxHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+using namespace arm_compute::experimental;
 
 namespace arm_compute
 {
@@ -37,7 +41,16 @@
 {
 template <bool IS_LOG>
 CpuSoftmaxGeneric<IS_LOG>::CpuSoftmaxGeneric()
-    : _permute_input(), _permute_output(), _max_kernel(), _softmax_kernel(), _max(nullptr), _tmp(nullptr), _input_permuted(nullptr), _output_permuted(nullptr), _needs_permute(false)
+    : _permute_input(),
+      _permute_output(),
+      _max_kernel(),
+      _softmax_kernel(),
+      _max(),
+      _tmp(),
+      _input_permuted(),
+      _output_permuted(),
+      _needs_permute(false),
+      _aux_mem(InternalTensorIdx::COUNT)
 {
 }
 
@@ -54,13 +67,12 @@
 
     if(_needs_permute)
     {
-        _input_permuted = std::make_unique<TensorInfo>();
-        _permute_input.configure(src, _input_permuted.get(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
+        _permute_input.configure(src, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
     }
 
     // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
     // or it is the original input case (2D case)
-    const ITensorInfo *tmp_input = (_needs_permute ? _input_permuted.get() : src);
+    const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
 
     // Create intermediate tensors shapes
     TensorShape max_sum_shape = tmp_input->tensor_shape();
@@ -71,31 +83,35 @@
     TensorInfo       max_info(tmp_input->clone()->set_tensor_shape(max_sum_shape));
 
     // Init intermediate tensors
-    _max = std::make_unique<TensorInfo>(max_info);
-    _tmp = std::make_unique<TensorInfo>(tensor_info_tmp);
+    _max = TensorInfo(max_info);
+    _tmp = TensorInfo(tensor_info_tmp);
 
     // Configure kernels
     auto mk = std::make_unique<kernels::CpuLogits1DMaxKernel>();
-    mk->configure(tmp_input, _max.get());
+    mk->configure(tmp_input, &_max);
     _max_kernel = std::move(mk);
 
     auto sm = std::make_unique<kernels::CpuLogits1DSoftmaxKernel<IS_LOG>>();
     if(_needs_permute)
     {
-        _output_permuted = std::make_unique<TensorInfo>();
-
         // The normalization kernel stores the result in a permuted output tensor
-        sm->configure(tmp_input, _max.get(), _output_permuted.get(), beta, _tmp.get());
+        sm->configure(tmp_input, &_max, &_output_permuted, beta, &_tmp);
 
         // Re-permute the permuted output into the requested (4D) output
-        _permute_output.configure(_output_permuted.get(), dst, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
+        _permute_output.configure(&_output_permuted, dst, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
     }
     else
     {
         // Softmax 2D case
-        sm->configure(tmp_input, _max.get(), dst, beta, _tmp.get());
+        sm->configure(tmp_input, &_max, dst, beta, &_tmp);
     }
     _softmax_kernel = std::move(sm);
+
+    _aux_mem[InternalTensorIdx::MAX] = MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max.total_size());
+    _aux_mem[InternalTensorIdx::TMP] = MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
+
+    _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), MemoryLifetime::Temporary, _input_permuted.total_size());
+    _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST), MemoryLifetime::Temporary, _output_permuted.total_size());
 }
 
 template <bool IS_LOG>
@@ -141,33 +157,45 @@
 {
     ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
 
+    auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
+    auto dst = tensors.get_tensor(TensorType::ACL_DST);
+
+    CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, false);
+    CpuAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max, tensors, false);
+
+    CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, false);
+    CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors, false);
+
     ITensorPack max_pack;
     ITensorPack softmax_pack;
 
     if(_needs_permute)
     {
-        ITensorPack permute_in_pack;
-        permute_in_pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(ACL_SRC));
-        permute_in_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_INT_2));
+        ITensorPack permute_in_pack = { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, input_permuted.get() } };
         _permute_input.run(permute_in_pack);
 
-        max_pack.add_tensor(TensorType::ACL_SRC, tensors.get_tensor(ACL_INT_2));
+        max_pack = { { TensorType::ACL_SRC, input_permuted.get() }, { TensorType::ACL_DST, max.get() } };
 
-        softmax_pack.add_tensor(TensorType::ACL_SRC_0, tensors.get_tensor(ACL_INT_2));
-        softmax_pack.add_tensor(TensorType::ACL_SRC_1, tensors.get_tensor(ACL_INT_1));
-        softmax_pack.add_tensor(TensorType::ACL_DST_0, tensors.get_tensor(ACL_INT_3));
-        softmax_pack.add_tensor(TensorType::ACL_DST_1, tensors.get_tensor(ACL_INT_0));
+        softmax_pack =
+        {
+            { TensorType::ACL_SRC_0, input_permuted.get() },
+            { TensorType::ACL_SRC_1, max.get() },
+            { TensorType::ACL_DST_0, output_permuted.get() },
+            { TensorType::ACL_DST_1, tmp.get() }
+        };
     }
     else
     {
-        max_pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(ACL_SRC));
-        softmax_pack.add_tensor(TensorType::ACL_SRC_0, tensors.get_const_tensor(ACL_SRC));
-        softmax_pack.add_tensor(TensorType::ACL_SRC_1, tensors.get_tensor(ACL_INT_1));
-        softmax_pack.add_tensor(TensorType::ACL_DST_0, tensors.get_tensor(ACL_DST));
-        softmax_pack.add_tensor(TensorType::ACL_DST_1, tensors.get_tensor(ACL_INT_0));
-    }
+        max_pack = { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, max.get() } };
 
-    max_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_INT_1));
+        softmax_pack =
+        {
+            { TensorType::ACL_SRC_0, src },
+            { TensorType::ACL_SRC_1, max.get() },
+            { TensorType::ACL_DST_0, dst },
+            { TensorType::ACL_DST_1, tmp.get() }
+        };
+    }
 
     NEScheduler::get().schedule_op(_max_kernel.get(), Window::DimY, _max_kernel->window(), max_pack);
     NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
@@ -175,8 +203,8 @@
     if(_needs_permute)
     {
         ITensorPack permute_out_pack;
-        permute_out_pack.add_tensor(TensorType::ACL_SRC, tensors.get_tensor(ACL_INT_3));
-        permute_out_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_DST));
+        permute_out_pack.add_tensor(TensorType::ACL_SRC, output_permuted.get());
+        permute_out_pack.add_tensor(TensorType::ACL_DST, dst);
         _permute_output.run(permute_out_pack);
     }
 }
@@ -184,18 +212,7 @@
 template <bool                   IS_LOG>
 experimental::MemoryRequirements CpuSoftmaxGeneric<IS_LOG>::workspace() const
 {
-    experimental::MemoryRequirements req{};
-
-    req.push_back({ TensorType::ACL_INT_0, _tmp->total_size(), 0 });
-    req.push_back({ TensorType::ACL_INT_1, _max->total_size(), 0 });
-
-    if(_needs_permute)
-    {
-        req.push_back({ TensorType::ACL_INT_2, _input_permuted->total_size(), 0 });
-        req.push_back({ TensorType::ACL_INT_3, _output_permuted->total_size(), 0 });
-    }
-
-    return req;
+    return _aux_mem;
 }
 
 template class CpuSoftmaxGeneric<false>;
diff --git a/src/runtime/cpu/operators/CpuSoftmax.h b/src/runtime/cpu/operators/CpuSoftmax.h
index 9f18e0e..3881797 100644
--- a/src/runtime/cpu/operators/CpuSoftmax.h
+++ b/src/runtime/cpu/operators/CpuSoftmax.h
@@ -24,7 +24,7 @@
 #ifndef ARM_COMPUTE_CPU_SOFTMAX_H
 #define ARM_COMPUTE_CPU_SOFTMAX_H
 
-#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/experimental/Types.h"
 #include "src/core/cpu/ICpuKernel.h"
 #include "src/runtime/cpu/ICpuOperator.h"
@@ -87,15 +87,27 @@
     experimental::MemoryRequirements workspace() const override;
 
 private:
-    CpuPermute                   _permute_input;
-    CpuPermute                   _permute_output;
-    std::unique_ptr<ICpuKernel>  _max_kernel;
-    std::unique_ptr<ICpuKernel>  _softmax_kernel;
-    std::unique_ptr<ITensorInfo> _max;
-    std::unique_ptr<ITensorInfo> _tmp;
-    std::unique_ptr<ITensorInfo> _input_permuted;
-    std::unique_ptr<ITensorInfo> _output_permuted;
-    bool                         _needs_permute;
+    enum InternalTensorIdx
+    {
+        MAX = 0,
+        TMP,
+        PERMUTED_SRC,
+        PERMUTED_DST,
+        COUNT
+    };
+
+    CpuPermute                  _permute_input;
+    CpuPermute                  _permute_output;
+    std::unique_ptr<ICpuKernel> _max_kernel;
+    std::unique_ptr<ICpuKernel> _softmax_kernel;
+
+    TensorInfo _max;
+    TensorInfo _tmp;
+    TensorInfo _input_permuted;
+    TensorInfo _output_permuted;
+
+    bool                             _needs_permute;
+    experimental::MemoryRequirements _aux_mem{};
 };
 using CpuSoftmax    = CpuSoftmaxGeneric<false>;
 using CpuLogSoftmax = CpuSoftmaxGeneric<true>;
diff --git a/src/runtime/cpu/utils/CpuAuxTensorHandler.h b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
new file mode 100644
index 0000000..644018a
--- /dev/null
+++ b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
+#define ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
+
+#include "arm_compute/core/ITensorPack.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/runtime/Tensor.h"
+
+#include "support/Cast.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+/* Tensor handler to wrap and handle tensor allocations on workspace buffers */
+class CpuAuxTensorHandler
+{
+public:
+    CpuAuxTensorHandler(int slot_id, TensorInfo &info, ITensorPack &pack, bool pack_inject = false)
+        : _tensor()
+    {
+        _tensor.allocator()->soft_init(info);
+
+        ITensor *packed_tensor = utils::cast::polymorphic_downcast<ITensor *>(pack.get_tensor(slot_id));
+        if((packed_tensor == nullptr) || (info.total_size() > packed_tensor->info()->total_size()))
+        {
+            _tensor.allocator()->allocate();
+            if(pack_inject)
+            {
+                pack.add_tensor(slot_id, &_tensor);
+                _injected_tensor_pack = &pack;
+                _injected_slot_id     = slot_id;
+            }
+        }
+        else
+        {
+            _tensor.allocator()->import_memory(packed_tensor->buffer());
+        }
+    }
+
+    CpuAuxTensorHandler(TensorInfo &info, ITensor &tensor)
+        : _tensor()
+    {
+        _tensor.allocator()->soft_init(info);
+        if(info.total_size() <= tensor.info()->total_size())
+        {
+            _tensor.allocator()->import_memory(tensor.buffer());
+        }
+    }
+
+    CpuAuxTensorHandler(const CpuAuxTensorHandler &) = delete;
+    CpuAuxTensorHandler &operator=(const CpuAuxTensorHandler &) = delete;
+
+    ~CpuAuxTensorHandler()
+    {
+        if(_injected_tensor_pack)
+        {
+            _injected_tensor_pack->remove_tensor(_injected_slot_id);
+        }
+    }
+
+    ITensor *get()
+    {
+        return &_tensor;
+    }
+
+    ITensor *operator()()
+    {
+        return &_tensor;
+    }
+
+private:
+    Tensor       _tensor{};
+    ITensorPack *_injected_tensor_pack{ nullptr };
+    int          _injected_slot_id{ TensorType::ACL_UNKNOWN };
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H */
\ No newline at end of file
diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.cpp b/src/runtime/gpu/cl/operators/ClSoftmax.cpp
index c3ec7cc..975bb0b 100644
--- a/src/runtime/gpu/cl/operators/ClSoftmax.cpp
+++ b/src/runtime/gpu/cl/operators/ClSoftmax.cpp
@@ -24,84 +24,32 @@
 #include "src/runtime/gpu/cl/operators/ClSoftmax.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
 #include "src/core/helpers/SoftmaxHelpers.h"
 #include "src/runtime/gpu/cl/operators/ClPermute.h"
+#include "src/runtime/gpu/cl/utils/ClAuxTensorHandler.h"
 #include "support/Cast.h"
 
+using namespace arm_compute::experimental;
+
 namespace arm_compute
 {
 namespace opencl
 {
-namespace
-{
-void run_permute(ClPermute *op, const ITensor *src, ITensor *dst)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, op);
-    ITensorPack pack;
-    pack.add_const_tensor(TensorType::ACL_SRC, src);
-    pack.add_tensor(TensorType::ACL_DST, dst);
-    op->run(pack);
-}
-} // namespace
-
 ClSoftmax::ClSoftmax()
     : _permute_input(std::make_unique<ClPermute>()),
       _permute_output(std::make_unique<ClPermute>()),
       _max_shift_exp_sum_kernel(std::make_unique<kernels::ClLogits1DMaxShiftExpSumKernel>()),
       _norm_kernel(std::make_unique<kernels::ClLogits1DNormKernel>()),
-      _max_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::MAX)]),
-      _sum_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::SUM)]),
-      _tmp_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::TMP)]),
-      _permuted_src_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)]),
-      _permuted_dst_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)])
+      _max_info(),
+      _sum_info(),
+      _tmp_info(),
+      _permuted_src_info(),
+      _permuted_dst_info(),
+      _aux_mem(InternalTensorIdx::COUNT)
 {
 }
 
-TensorType ClSoftmax::convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const
-{
-    switch(idx)
-    {
-        case InternalTensorIdx::MAX:
-            return TensorType::ACL_INT_0;
-        case InternalTensorIdx::SUM:
-            return TensorType::ACL_INT_1;
-        case InternalTensorIdx::TMP:
-            return TensorType::ACL_INT_2;
-        case InternalTensorIdx::PERMUTED_SRC:
-            return TensorType::ACL_INT_3;
-        case InternalTensorIdx::PERMUTED_DST:
-            return TensorType::ACL_INT_4;
-        default:
-            ARM_COMPUTE_ERROR("invalid internal tensor index is given.");
-            break;
-    };
-    return TensorType::ACL_UNKNOWN;
-}
-
-void ClSoftmax::create_internal_tensor(TensorInfo &info, InternalTensorIdx idx)
-{
-    const auto tensor_idx = static_cast<uint32_t>(idx);
-    if(!_internal_tensor[tensor_idx])
-    {
-        _internal_tensor[tensor_idx] = std::make_unique<CLTensor>();
-    }
-    _internal_tensor[tensor_idx]->allocator()->init(info);
-}
-
-void ClSoftmax::create_internal_tensor()
-{
-    for(uint32_t i = 0; i < static_cast<uint32_t>(InternalTensorIdx::COUNT); i++)
-    {
-        const auto tensor_idx = static_cast<InternalTensorIdx>(i);
-
-        if(!_needs_permute && (tensor_idx == InternalTensorIdx::PERMUTED_DST || tensor_idx == InternalTensorIdx::PERMUTED_SRC))
-        {
-            continue;
-        }
-        create_internal_tensor(_internal_info[i], static_cast<InternalTensorIdx>(i));
-    }
-}
-
 void ClSoftmax::configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info)
 {
     ARM_COMPUTE_ERROR_THROW_ON(validate(src, dst, info));
@@ -137,6 +85,13 @@
         const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
         _permute_output->configure(compile_context, &_permuted_dst_info, &dst, perm_info);
     }
+
+    _aux_mem[InternalTensorIdx::SUM] = MemoryInfo(offset_int_vec(InternalTensorIdx::SUM), MemoryLifetime::Temporary, _sum_info.total_size());
+    _aux_mem[InternalTensorIdx::TMP] = MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp_info.total_size());
+    _aux_mem[InternalTensorIdx::MAX] = MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max_info.total_size());
+
+    _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), MemoryLifetime::Temporary, _permuted_src_info.total_size());
+    _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST), MemoryLifetime::Temporary, _permuted_dst_info.total_size());
 }
 
 Status ClSoftmax::validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
@@ -172,105 +127,60 @@
     return Status{};
 }
 
-void ClSoftmax::import_workspace_memory(ITensorPack &tensors)
-{
-    auto import_workspace_memory = [this, &tensors](InternalTensorIdx idx)
-    {
-        const auto workspace_idx   = convert_internal_idx_to_tensor_type(idx);
-        auto       imported_tensor = tensors.get_tensor(workspace_idx);
-        if(imported_tensor)
-        {
-            auto imported_memory = utils::cast::polymorphic_downcast<ICLTensor *>(imported_tensor)->cl_buffer();
-            _internal_tensor[static_cast<uint32_t>(idx)].get()->allocator()->import_memory(imported_memory);
-        }
-    };
-
-    import_workspace_memory(InternalTensorIdx::PERMUTED_SRC);
-    import_workspace_memory(InternalTensorIdx::PERMUTED_DST);
-    import_workspace_memory(InternalTensorIdx::MAX);
-    import_workspace_memory(InternalTensorIdx::SUM);
-    import_workspace_memory(InternalTensorIdx::TMP);
-}
-
-void ClSoftmax::run_source_permute(const ITensor *src)
-{
-    if(_needs_permute)
-    {
-        auto permuted_src = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get();
-        run_permute(_permute_input.get(), src, permuted_src);
-    }
-}
-
-void ClSoftmax::run_destination_permute(ITensor *dst)
-{
-    if(_needs_permute)
-    {
-        auto permuted_dst = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get();
-        run_permute(_permute_output.get(), permuted_dst, dst);
-    }
-}
-
-void ClSoftmax::run_max_sum(const ITensor *src)
-{
-    auto max = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::MAX)].get();
-    auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
-    auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
-
-    ARM_COMPUTE_ERROR_ON_NULLPTR(src, tmp, max, sum);
-
-    ITensorPack sum_pack;
-    sum_pack.add_const_tensor(TensorType::ACL_SRC, src);
-    sum_pack.add_tensor(TensorType::ACL_DST, tmp);
-    sum_pack.add_tensor(TensorType::ACL_INT_0, max);
-    sum_pack.add_tensor(TensorType::ACL_INT_1, sum);
-
-    CLScheduler::get().enqueue_op(*_max_shift_exp_sum_kernel.get(), sum_pack, false);
-}
-
-void ClSoftmax::run_norm(ITensor *dst)
-{
-    auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
-    auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
-
-    ARM_COMPUTE_ERROR_ON_NULLPTR(tmp, sum, dst);
-
-    ITensorPack norm_pack;
-    norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp);
-    norm_pack.add_tensor(TensorType::ACL_DST, dst);
-    norm_pack.add_tensor(TensorType::ACL_INT_0, sum);
-
-    CLScheduler::get().enqueue_op(*_norm_kernel.get(), norm_pack, false);
-}
-
 void ClSoftmax::run(ITensorPack &tensors)
 {
-    create_internal_tensor();
-
     auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
     auto dst = tensors.get_tensor(TensorType::ACL_DST);
 
-    import_workspace_memory(tensors);
-    run_source_permute(src);
-    run_max_sum(!_needs_permute ? src : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get());
-    run_norm(!_needs_permute ? dst : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get());
-    run_destination_permute(dst);
+    CLAuxTensorHandler sum(offset_int_vec(InternalTensorIdx::SUM), _sum_info, tensors, false);
+    CLAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp_info, tensors, false);
+    CLAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max_info, tensors, false);
+
+    CLAuxTensorHandler permuted_src(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info, tensors, false);
+    CLAuxTensorHandler permuted_dst(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info, tensors, false);
+
+    if(_needs_permute)
+    {
+        ITensorPack pack;
+        pack.add_const_tensor(TensorType::ACL_SRC, src);
+        pack.add_tensor(TensorType::ACL_DST, permuted_src.get());
+        _permute_input.get()->run(pack);
+    }
+
+    ITensorPack sum_pack;
+    ITensorPack norm_pack;
+    if(_needs_permute)
+    {
+        sum_pack.add_const_tensor(TensorType::ACL_SRC, permuted_src.get());
+        norm_pack.add_tensor(TensorType::ACL_DST, permuted_dst.get());
+    }
+    else
+    {
+        sum_pack.add_const_tensor(TensorType::ACL_SRC, src);
+        norm_pack.add_tensor(TensorType::ACL_DST, dst);
+    }
+    sum_pack.add_tensor(TensorType::ACL_DST, tmp.get());
+    sum_pack.add_tensor(TensorType::ACL_INT_0, max.get());
+    sum_pack.add_tensor(TensorType::ACL_INT_1, sum.get());
+
+    norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp.get());
+    norm_pack.add_tensor(TensorType::ACL_INT_0, sum.get());
+
+    CLScheduler::get().enqueue_op(*_max_shift_exp_sum_kernel.get(), sum_pack, false);
+    CLScheduler::get().enqueue_op(*_norm_kernel.get(), norm_pack, false);
+
+    if(_needs_permute)
+    {
+        ITensorPack pack;
+        pack.add_const_tensor(TensorType::ACL_SRC, permuted_dst.get());
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _permute_output.get()->run(pack);
+    }
 }
 
 experimental::MemoryRequirements ClSoftmax::workspace() const
 {
-    experimental::MemoryRequirements req{};
-
-    req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::SUM), _sum_info.total_size(), 0);
-    req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::TMP), _tmp_info.total_size(), 0);
-    req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::MAX), _max_info.total_size(), 0);
-
-    if(_needs_permute)
-    {
-        req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info.total_size(), 0);
-        req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info.total_size(), 0);
-    }
-
-    return req;
+    return _aux_mem;
 }
 } // namespace opencl
 } // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.h b/src/runtime/gpu/cl/operators/ClSoftmax.h
index e38b7c5..f19a51f 100644
--- a/src/runtime/gpu/cl/operators/ClSoftmax.h
+++ b/src/runtime/gpu/cl/operators/ClSoftmax.h
@@ -67,7 +67,7 @@
     experimental::MemoryRequirements workspace() const override;
 
 private:
-    enum class InternalTensorIdx
+    enum InternalTensorIdx
     {
         MAX = 0,
         SUM,
@@ -77,41 +77,19 @@
         COUNT
     };
 
-    /** Create a single internal tensor
-     *
-     * @param[in] info The information used to create a tensor
-     * @param[in] idx  The index within the internal array the created tensor will be held
-     */
-    void create_internal_tensor(TensorInfo &info, InternalTensorIdx idx);
-    /** Create all required internal tensors */
-    void create_internal_tensor();
-    /** Function to convert from internal tensor index to @ref TensorType used externally */
-    TensorType convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const;
-    /** Function to import workspace memory allocated by the caller into internal tensor instances */
-    void import_workspace_memory(ITensorPack &tensors);
-    /** Function to permute the given source tensor when permutation is required */
-    void run_source_permute(const ITensor *src);
-    /** Function to permute the intemediate tensor to the final destination tensor when permutation is required */
-    void run_destination_permute(ITensor *dst);
-    /** Function to run @ref arm_compute::opencl::kernels::ClLogits1DMaxShiftExpSumKernel */
-    void run_max_sum(const ITensor *src);
-    /** Function to run @ref kernels::ClLogits1DNormKernel */
-    void run_norm(ITensor *dst);
-
     std::unique_ptr<ClPermute>                               _permute_input;
     std::unique_ptr<ClPermute>                               _permute_output;
     std::unique_ptr<kernels::ClLogits1DMaxShiftExpSumKernel> _max_shift_exp_sum_kernel;
     std::unique_ptr<kernels::ClLogits1DNormKernel>           _norm_kernel;
     bool                                                     _needs_permute{ false };
 
-    std::array<TensorInfo, static_cast<uint32_t>(InternalTensorIdx::COUNT)>                _internal_info{};
-    std::array<std::unique_ptr<CLTensor>, static_cast<uint32_t>(InternalTensorIdx::COUNT)> _internal_tensor{};
+    TensorInfo _max_info;
+    TensorInfo _sum_info;
+    TensorInfo _tmp_info;
+    TensorInfo _permuted_src_info;
+    TensorInfo _permuted_dst_info;
 
-    TensorInfo &_max_info;
-    TensorInfo &_sum_info;
-    TensorInfo &_tmp_info;
-    TensorInfo &_permuted_src_info;
-    TensorInfo &_permuted_dst_info;
+    experimental::MemoryRequirements _aux_mem{};
 };
 
 } // opencl