Integrate multi-threaded pretranspose_B_array

This is required for the case where the rhs (B) is dynamic and needs to be
pretransposed on every run.

In a multi-threaded setting, the previously single-threaded
pretranspose_B_array would otherwise become the bottleneck.
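
For illustration only (not part of this patch), a minimal standalone sketch of
the 1D static schedule used by the new run_parallel_pretranspose_B_array
helper: each thread takes a contiguous [start, end) slice of the pretranspose
window. The wsize and num_threads values below are made up for the example; in
the patch, wsize comes from gemm_asm->get_B_pretranspose_window_size().

    // Sketch of the 1D static split of the pretranspose window across threads.
    #include <cstdio>

    int main()
    {
        const unsigned int wsize       = 7; // example window size
        const unsigned int num_threads = 3; // example thread count
        for(unsigned int t = 0; t < num_threads; ++t)
        {
            const unsigned int start = (t * wsize) / num_threads;
            const unsigned int end   = ((t + 1) * wsize) / num_threads;
            if(start < end)
            {
                // In the patch, thread t calls pretranspose_B_array_part(..., start, end) here
                std::printf("thread %u handles [%u, %u)\n", t, start, end);
            }
        }
        return 0; // prints: [0, 2), [2, 4), [4, 7)
    }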

Resolves COMPMID-5896

Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Id508c46992188a0f76a505152931d4955d04c16d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9455
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 9af98be..9c85631 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -38,6 +38,46 @@
 {
 namespace cpu
 {
+namespace
+{
+/** Run pretranspose_B_array in parallel (1D static scheduling)
+ *
+ * @tparam TypeInput  Input data type of the GemmCommon kernel
+ * @tparam TypeOutput Output data type of the GemmCommon kernel
+ *
+ * @param[in]  gemm_asm         GemmCommon kernel to run
+ * @param[out] dst              Pretransposed B array
+ * @param[in]  src              B array to be pretransposed
+ * @param[in]  src_ld           Stride in y
+ * @param[in]  src_multi_stride Stride in z ("multi")
+ * @param[in]  num_threads      Number of threads to run this method. Must be >= 1
+ */
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads)
+{
+    ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
+    ARM_COMPUTE_ERROR_ON(num_threads == 0);
+    // The window size is also the total workload size
+    const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size();
+
+    std::vector<IScheduler::Workload> workloads(num_threads);
+    for(unsigned int t = 0; t < num_threads; ++t)
+    {
+        workloads[t] = [ = ](const ThreadInfo & info)
+        {
+            const unsigned int start = (info.thread_id * wsize) / num_threads;
+            const unsigned int end   = ((info.thread_id + 1) * wsize) / num_threads;
+
+            if(start < end)
+            {
+                gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+            }
+        };
+    }
+    NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array");
+}
+} // namespace
+
 using namespace arm_compute::experimental;
 
 namespace
@@ -436,7 +476,7 @@
 
             CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
             ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
-            _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
+            run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
 
             b->mark_as_unused();
         }
@@ -493,9 +533,9 @@
     // Check if B is pre-tranposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb                                = b->info()->strides_in_bytes().y() / b->info()->element_size();
-        multi_stride_b                     = b->info()->strides_in_bytes().z() / b->info()->element_size();
-        in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+        ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
+        multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
     }
 
     // If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -522,7 +562,7 @@
             }
             else
             {
-                _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+                run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
             }
         }
     }