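COMPMID-1209: Enable memory manager for the GEMM workspace buffer

Register the assembly GEMM workspace tensor with the memory group when one
is provided, instead of always allocating it eagerly. A managed tensor may
not have backing memory at configure time, so the workspace pointer is no
longer passed to the assembly kernel during setup. AssemblyKernelGlue now
holds the workspace tensor and sets the working space right before the
kernel is scheduled.
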
Change-Id: I125660d412945aa152cb76c78280ca0d52264b86
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/133372
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index ecaf35a..3aa43ec 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -51,7 +51,7 @@
     using TypeResult = TypeOutput;
     /** Default constructor. */
     AssemblyKernelGlue()
-        : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _pretranspose(nullptr)
+        : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr) // _workspace stays null unless a GEMM workspace is assigned; a null workspace is skipped at execution time
     {
     }
     /** Assembly Gemm */
@@ -72,6 +72,8 @@
     const ITensor *_b;
     /** Output */
     ITensor *_d;
+    /** GEMM workspace */
+    ITensor *_workspace;
     /** Pre-transpose tensor */
     ITensor *_pretranspose;
 
@@ -100,7 +102,16 @@
         const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
         auto       out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
 
+        // Set workspace if needed
+        if(_workspace != nullptr)
+        {
+            _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace->buffer()));
+        }
+
+        // Set gemm parameters
         _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+
+        // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
             // Forcing 128-byte alignment (required by 32-bit kernels)
@@ -113,6 +124,7 @@
             _b->mark_as_unused();
         }
 
+        // Schedule assembly kernel
         NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
     }
 };
@@ -134,9 +146,12 @@
  */
 inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup *memory_group, size_t alignment, unsigned int num_threads)
 {
-    ARM_COMPUTE_UNUSED(memory_group);
     ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
     workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment - 1) * num_threads }, 1, DataType::S8));
+    if(memory_group != nullptr) // memory management is optional: with no group, allocate() below is used on its own
+    {
+        memory_group->manage(&workspace); // registered before allocate(); the buffer may not be valid yet, so callers must not read it here and the kernel receives it at execution time
+    }
     workspace.allocator()->allocate();
 }
 
@@ -182,8 +197,7 @@
             // Allocate workspace
             const unsigned int alignment = 4096;
             allocate_workspace(workspace_size, workspace, &memory_group, alignment, num_threads);
-            ARM_COMPUTE_ERROR_ON_NULLPTR(workspace.buffer());
-            asm_gemm->set_working_space(reinterpret_cast<typename T::TypeResult *>(workspace.buffer()));
+            asm_glue._workspace = &workspace;
         }
 
         //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and