Re-use auxiliary memory within CpuWinogradConv2d operators

The input and output transformation operations are independent and run
in different time-steps of the algorithm, so their auxiliary memory can
be re-used between the two transformation stages.
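
As landed in the diff below, a single auxiliary slot now covers both
workspaces by sizing it to the larger of the two requirements:

    // One temporary buffer serves both transform stages, since the
    // input and output transformations never run at the same time.
    _aux_mem[WorkspaceIO] = MemoryInfo(offset_int_vec(WorkspaceIO),
                                       MemoryLifetime::Temporary,
                                       std::max(input_workspace_size,
                                                output_workspace_size));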

Moreover, reduce over-allocation when extracting workspace sizes for
the Winograd transformations. The size getters return a mix of bytes
and elements, so ensure the correct unit is used in each case: the
storage_size() member functions return elements, while the
working_space() functions return bytes.
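
For example, a storage size reported in elements is scaled by the
element size before being requested as auxiliary memory, as in the
change below:

    // get_input_storage_size() returns elements, so convert to bytes.
    const size_t input_storage_size =
        transform_input_kernel->get_input_storage_size(
            in_shape.n_batches, in_shape.n_channels,
            in_shape.n_rows, in_shape.n_cols, use_same_padding)
        * data_type_size;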

Resolves: COMPMID-4781

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I705445ba7ca818cead48369db3cacd49684c7192
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6145
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
diff --git a/src/runtime/cpu/operators/CpuWinogradConv2d.cpp b/src/runtime/cpu/operators/CpuWinogradConv2d.cpp
index ca7b004..253280a 100644
--- a/src/runtime/cpu/operators/CpuWinogradConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuWinogradConv2d.cpp
@@ -494,14 +494,10 @@
     constexpr size_t storage_alignment = 64;
 
     // Kernel Storage
-    const size_t kernel_storage_size = transform_weights_kernel->get_weight_storage_size(out_channels,
-                                                                                         in_channels)
-                                       * data_type_size;
+    const size_t kernel_storage_size = transform_weights_kernel->get_weight_storage_size(out_channels, in_channels) * data_type_size;
 
     // Input storage
-    const size_t input_storage_size = transform_input_kernel->get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols,
-                                                                                     use_same_padding)
-                                      * data_type_size;
+    const size_t input_storage_size = transform_input_kernel->get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, use_same_padding) * data_type_size;
 
     // Output storage
     const size_t output_storage_size  = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels) * data_type_size;
@@ -558,7 +554,6 @@
     if(_data_layout == DataLayout::NCHW)
     {
         _permute_input->configure(src, &_input_nhwc, PermutationVector(2U, 0U, 1U));
-        _aux_mem[PermutedInput]    = MemoryInfo(offset_int_vec(PermutedInput), MemoryLifetime::Temporary, src->total_size());
         input_to_use               = &_input_nhwc;
         weights_permutation_vector = PermutationVector(3U, 2U, 0U, 1U);
     }
@@ -609,7 +604,6 @@
     if(_data_layout == DataLayout::NCHW)
     {
         _permute_output->configure(&_output_nhwc, dst, PermutationVector(1U, 2U, 0U));
-        _aux_mem[PermutedOutput] = MemoryInfo(offset_int_vec(PermutedOutput), MemoryLifetime::Temporary, dst->total_size());
     }
 
     _transform_input_kernel   = std::move(transform_input_kernel);
@@ -630,12 +624,17 @@
     _aux_mem[TransposedRHS]  = asm_mem_req[TransposedRHS];
     _aux_mem[TempResult]     = asm_mem_req[TempResult];
 
-    _aux_mem[InputTransformed]   = MemoryInfo(offset_int_vec(InputTransformed), MemoryLifetime::Temporary, input_storage_size, storage_alignment);
-    _aux_mem[InputWorkspace]     = MemoryInfo(offset_int_vec(InputWorkspace), MemoryLifetime::Temporary, input_workspace_size);
+    // Request temporary memory. Overlap memory needed for Input/Output transformations as they run on different non-overlapping time-steps.
+    _aux_mem[TransformedInput]   = MemoryInfo(offset_int_vec(TransformedInput), MemoryLifetime::Temporary, input_storage_size, storage_alignment);
+    _aux_mem[TransformedOutput]  = MemoryInfo(offset_int_vec(TransformedOutput), MemoryLifetime::Temporary, output_storage_size, storage_alignment);
+    _aux_mem[WorkspaceIO]        = MemoryInfo(offset_int_vec(WorkspaceIO), MemoryLifetime::Temporary, std::max(input_workspace_size, output_workspace_size));
     _aux_mem[PermutedWeights]    = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, _weights_hwio.total_size());
-    _aux_mem[WeightsTransformed] = MemoryInfo(offset_int_vec(WeightsTransformed), MemoryLifetime::Persistent, kernel_storage_size, storage_alignment);
-    _aux_mem[OutputTransformed]  = MemoryInfo(offset_int_vec(OutputTransformed), MemoryLifetime::Temporary, output_storage_size, storage_alignment);
-    _aux_mem[OutputWorkspace]    = MemoryInfo(offset_int_vec(OutputWorkspace), MemoryLifetime::Temporary, output_workspace_size);
+    _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Persistent, kernel_storage_size, storage_alignment);
+    if(_data_layout == DataLayout::NCHW)
+    {
+        _aux_mem[PermutedInput].merge(offset_int_vec(PermutedInput), src->total_size());
+        _aux_mem[PermutedOutput].merge(offset_int_vec(PermutedOutput), dst->total_size());
+    }
 }
 
 Status CpuWinogradConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
@@ -757,9 +756,8 @@
     auto d = tensors.get_tensor(ACL_DST);
 
     CpuAuxTensorHandler input_nhwc(offset_int_vec(PermutedInput), _input_nhwc, tensors, true);
-    CpuAuxTensorHandler output_nhwc(offset_int_vec(PermutedOutput), _output_nhwc, tensors, true);
-    CpuAuxTensorHandler input_transformed(offset_int_vec(InputTransformed), _input_transformed, tensors, true);
-    CpuAuxTensorHandler input_workspace(offset_int_vec(InputWorkspace), _input_workspace, tensors, true);
+    CpuAuxTensorHandler input_transformed(offset_int_vec(TransformedInput), _input_transformed, tensors, true);
+    CpuAuxTensorHandler input_workspace(offset_int_vec(WorkspaceIO), _input_workspace, tensors, true);
 
     const bool is_nchw = _data_layout == DataLayout::NCHW;
     if(is_nchw)
@@ -773,15 +771,20 @@
     ITensorPack transform_input_pack{ { ACL_SRC, is_nchw ? input_nhwc.get() : a }, { ACL_DST, input_transformed.get() }, { ACL_INT, input_workspace.get() } };
     NEScheduler::get().schedule_op(_transform_input_kernel.get(), Window::DimX, _transform_input_kernel->window(), transform_input_pack);
 
-    CpuAuxTensorHandler output_transformed(offset_int_vec(OutputTransformed), _output_transformed, tensors, true);
-    CpuAuxTensorHandler weights_transformed(offset_int_vec(WeightsTransformed), _kernel_storage, tensors, true);
+    CpuAuxTensorHandler output_transformed(offset_int_vec(TransformedOutput), _output_transformed, tensors, true);
+    CpuAuxTensorHandler weights_transformed(offset_int_vec(TransformedWeights), _kernel_storage, tensors, true);
 
     // Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs
-    ITensorPack gemm_pack{ { ACL_SRC, input_transformed.get() }, { ACL_SRC_1, weights_transformed.get() }, { ACL_DST, output_transformed.get() } };
+    ITensorPack gemm_pack = tensors;
+    gemm_pack.add_const_tensor(ACL_SRC, input_transformed.get());
+    gemm_pack.add_const_tensor(ACL_SRC_1, weights_transformed.get());
+    gemm_pack.add_const_tensor(ACL_BIAS, nullptr);
+    gemm_pack.add_tensor(ACL_DST, output_transformed.get());
     _gemm_function->run(gemm_pack);
 
     // Transform output tensor to the spatial domain
-    CpuAuxTensorHandler output_workspace(offset_int_vec(OutputWorkspace), _output_workspace, tensors, true);
+    CpuAuxTensorHandler output_workspace(offset_int_vec(WorkspaceIO), _output_workspace, tensors, true);
+    CpuAuxTensorHandler output_nhwc(offset_int_vec(PermutedOutput), _output_nhwc, tensors, true);
     ITensorPack         transform_output_pack{ { ACL_SRC_0, c }, { ACL_SRC_1, output_transformed.get() }, { ACL_DST, is_nchw ? output_nhwc.get() : d }, { ACL_INT, output_workspace.get() } };
     NEScheduler::get().schedule_op(_transform_output_kernel.get(), Window::DimX, _transform_output_kernel->window(), transform_output_pack);
 
@@ -813,7 +816,7 @@
         _permute_weights->run(permute_tensors);
 
         // Transform weights
-        ITensor *weights_transf = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(WeightsTransformed)));
+        ITensor *weights_transf = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransformedWeights)));
         ARM_COMPUTE_ERROR_ON_NULLPTR(weights_transf);
 
         CpuAuxTensorHandler transformed_weights(_kernel_storage, *weights_transf);