Port NEGEMMConv2d to memory injecting interface

Resolves: COMPMID-4506, COMPMID-4570

Change-Id: I6d37a06da141f1fcfcaa8525322a319cb0234791
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5824
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index e50099d..c2e9f24 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -26,10 +26,10 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuPermute.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+#include "support/Cast.h"
 
 #include <set>
 
@@ -37,6 +37,9 @@
 {
 namespace cpu
 {
+using namespace arm_compute::experimental;
+using namespace arm_compute::utils::cast;
+
 namespace
 {
 GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act)
@@ -87,12 +90,14 @@
 }
 } // namespace
 
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
-    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+    : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
       _activation_func(std::make_unique<CpuActivation>()),
       _weights_permute_func(std::make_unique<CpuPermute>()),
-      _permuted_weights_info(),
-      _permuted_weights(std::make_unique<Tensor>())
+      _aux_mem(AuxTensorIdx::Count),
+      _perm_weights(),
+      _run_activation(false),
+      _is_prepared(false)
 {
 }
 
@@ -106,8 +111,10 @@
                                                              biases != nullptr ? biases : nullptr,
                                                              dst,
                                                              info));
-    _original_weights_info = weights;
-    _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 });
+    _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info);
+    _is_prepared    = false;
+
+    _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 });
 
     // Configure assembly dispatch
     cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false);
@@ -115,13 +122,27 @@
     {
         asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info);
     }
-    _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info);
+    _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info);
 
     // Configure activation
-    if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info))
+    if(_run_activation)
     {
         _activation_func->configure(dst, nullptr, info.act_info);
-        _run_activation = true;
+    }
+
+    // Add auxiliary memory requirements of the assembly dispatch
+    auto asm_mem_req           = _gemm_asm_func->workspace();
+    _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+    _aux_mem[Pretranspose]     = asm_mem_req[Pretranspose];
+
+    if(_aux_mem[Pretranspose].size > 0)
+    {
+        // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size());
+    }
+    else
+    {
+        _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
     }
 }
 Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -172,35 +193,29 @@
     }
 }
 
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
-    // TODO: This function will be removed when memory injection is implemeted.
-    ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
-    _permuted_weights->allocator()->free();
-    _permuted_weights->allocator()->init(_permuted_weights_info);
-    _permuted_weights->allocator()->allocate();
-}
-
 void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
 {
     if(!_is_prepared)
     {
-        allocate_permuted_weights();
-        ITensorPack permute_tensors
-        {
-            { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
-            { TensorType::ACL_DST, _permuted_weights.get() },
-        };
+        const ITensor *weights     = tensors.get_const_tensor(ACL_SRC_1);
+        ITensor       *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
+        ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
 
+        CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux);
+        ITensorPack         permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } };
         _weights_permute_func->run(permute_tensors);
 
-        tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused();
+        tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get());
+        // Call prepare of assembly dispatch
+        _gemm_asm_func->prepare(tensors);
 
-        // switch the original tensor with permuted tensor
-        tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get());
         _is_prepared = true;
     }
 }
 
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+    return _aux_mem;
+}
 } // namespace cpu
 } // namespace arm_compute
\ No newline at end of file