COMPMID-959: Perform B-matrix pretranspose in NEON assembly GEMM when allowed

Change-Id: I281699ce7270aec1317c47b5a13799954cf6c9e8
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130010
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 2b4f35f..ee09ef5 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -51,7 +51,7 @@
     using TypeResult = TypeOutput;
     /** Default constructor. */
     AssemblyKernelGlue()
-        : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr)
+        : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _pretranspose(nullptr)
     {
     }
     /** Assembly Gemm */
@@ -72,6 +72,8 @@
     const ITensor *_b;
     /** Output */
     ITensor *_d;
+    /** Pre-transpose tensor */
+    ITensor *_pretranspose;
 
     /** Configures the arrays pointers and strides in the assembly kernel and executes the assembly kernel.
      *  The call to set_arrays is needed to deal with the input sizes containing batches (dims > 2)
@@ -94,6 +96,12 @@
         auto       out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
 
         _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+        if(_gemm_kernel_asm->B_pretranspose_required())
+        {
+            ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr);
+            _gemm_kernel_asm->pretranspose_B_array(reinterpret_cast<void *>(_pretranspose->buffer()), in1_ptr, ldb, multi_stride_b);
+        }
+
         NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
     }
 };
@@ -113,8 +121,9 @@
  * @param[in]  alignment      Workspace memory alignment.
  * @param[in]  num_threads    Number of workspace threads.
  */
-inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup &memory_group, size_t alignment, unsigned int num_threads)
+inline void allocate_workspace(size_t workspace_size, Tensor &workspace, MemoryGroup *memory_group, size_t alignment, unsigned int num_threads)
 {
+    ARM_COMPUTE_UNUSED(memory_group);
     ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
     workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment - 1) * num_threads }, 1, DataType::S8));
     workspace.allocator()->allocate();
@@ -122,20 +131,22 @@
 
 /** Create a wrapper kernel.
  *
- * @param[in]  a            Input tensor A.
- * @param[in]  b            Input tensor B.
- * @param[out] d            Output tensor.
- * @param[in]  alpha        Alpha value.
- * @param[in]  beta         Beta value.
- * @param[out] workspace    Workspace tensor
- * @param[in]  memory_group Tensor memory group.
- * @param[out] asm_glue     Assembly glue kernel.
+ * @param[in]  a                 Input tensor A.
+ * @param[in]  b                 Input tensor B.
+ * @param[out] d                 Output tensor.
+ * @param[in]  alpha             Alpha value.
+ * @param[in]  beta              Beta value.
+ * @param[in]  pretranspose_hint Pre-transpose hint in case matrix b should be pre-transposed
+ * @param[out] workspace         Workspace tensor
+ * @param[out] B_pretranspose    Tensor to hold the pre-transposed B
+ * @param[in]  memory_group      Tensor memory group.
+ * @param[out] asm_glue          Assembly glue kernel.
  *
  * @return the wrapper kernel.
  */
 template <typename T>
-inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta,
-                                  Tensor &workspace, MemoryGroup &memory_group, T &asm_glue)
+inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                  Tensor &workspace, Tensor &B_pretranspose, MemoryGroup &memory_group, T &asm_glue)
 {
     const CPUInfo &ci          = NEScheduler::get().cpu_info();
     const int      M           = d->info()->tensor_shape().y();
@@ -147,7 +158,7 @@
 
     // unique_ptr to a Gemm object
     std::unique_ptr<typename T::AssemblyGemm>
-    asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false));
+    asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint));
     // arm_compute wrapper for the Gemm object (see above)
     std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>
                                                                   acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>();
@@ -159,7 +170,7 @@
         {
             // Allocate workspace
             const unsigned int alignment = 4096;
-            allocate_workspace(workspace_size, workspace, memory_group, alignment, num_threads);
+            allocate_workspace(workspace_size, workspace, &memory_group, alignment, num_threads);
             ARM_COMPUTE_ERROR_ON_NULLPTR(workspace.buffer());
             asm_gemm->set_working_space(reinterpret_cast<typename T::TypeResult *>(workspace.buffer()));
         }
@@ -175,6 +186,15 @@
             }
         }
 
+        // Check for pre-transposed support
+        if(asm_gemm->B_pretranspose_required())
+        {
+            const size_t B_pretranspose_size = asm_gemm->get_B_pretransposed_array_size();
+            allocate_workspace(B_pretranspose_size, B_pretranspose, nullptr, 1, 1);
+            ARM_COMPUTE_ERROR_ON_NULLPTR(B_pretranspose.buffer());
+            asm_glue._pretranspose = &B_pretranspose;
+        }
+
         asm_glue._gemm_kernel_asm  = std::move(asm_gemm);
         asm_glue._optimised_kernel = std::move(acl_gemm_wrapper);
         // We need to setup the ptrs in the run() method