Integrate multi-threaded pretranspose_B_array
This is required for the case where the rhs (B) is dynamic and therefore
needs to be pretransposed on every run. In a multi-threaded setting, the
previously single-threaded pretranspose_B_array would otherwise become
the bottleneck. The 1D static split it uses is sketched below.
Resolves COMPMID-5896
Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Id508c46992188a0f76a505152931d4955d04c16d
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9455
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
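For context, the new run_parallel_pretranspose_B_array splits the pretranspose
window statically: thread t gets the half-open range
[t * wsize / num_threads, (t + 1) * wsize / num_threads). A minimal standalone
sketch of that arithmetic (split_1d and its names are illustrative, not part of
the library):

    #include <cstdio>
    #include <utility>
    #include <vector>

    // Evenly split `wsize` units of work across `num_threads` threads,
    // mirroring the static 1D scheduling used in the patch below.
    std::vector<std::pair<unsigned int, unsigned int>> split_1d(unsigned int wsize, unsigned int num_threads)
    {
        std::vector<std::pair<unsigned int, unsigned int>> ranges(num_threads);
        for(unsigned int t = 0; t < num_threads; ++t)
        {
            ranges[t] = { (t * wsize) / num_threads, ((t + 1) * wsize) / num_threads };
        }
        return ranges;
    }

    int main()
    {
        // With wsize = 10 and 4 threads the split is [0,2) [2,5) [5,7) [7,10).
        // With num_threads > wsize some ranges come out empty, which is why each
        // workload in the patch guards on `start < end` before calling the kernel.
        for(const auto &r : split_1d(10, 4))
        {
            std::printf("[%u, %u)\n", r.first, r.second);
        }
        return 0;
    }

Every unit of the window is covered exactly once, so the workloads need no
synchronisation between them.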
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 9af98be..9c85631 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -38,6 +38,46 @@
{
namespace cpu
{
+namespace
+{
+/** Run pretranspose_B_array in parallel (1D static scheduling)
+ *
+ * @tparam TypeInput  Data type of the B array
+ * @tparam TypeOutput Data type of the pretransposed B array
+ *
+ * @param[in]  gemm_asm         GemmCommon kernel to run
+ * @param[out] dst              Pretransposed B array
+ * @param[in]  src              B array to be pretransposed
+ * @param[in]  src_ld           Stride in y
+ * @param[in]  src_multi_stride Stride in z ("multi")
+ * @param[in]  num_threads      Number of threads to run this method. Must be >= 1
+ */
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads)
+{
+    ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
+    ARM_COMPUTE_ERROR_ON(num_threads == 0);
+    // The window size is also the total workload size
+    const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size();
+
+    std::vector<IScheduler::Workload> workloads(num_threads);
+    for(unsigned int t = 0; t < num_threads; ++t)
+    {
+        workloads[t] = [ = ](const ThreadInfo & info)
+        {
+            const unsigned int start = (info.thread_id * wsize) / num_threads;
+            const unsigned int end   = ((info.thread_id + 1) * wsize) / num_threads;
+
+            if(start < end)
+            {
+                gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+            }
+        };
+    }
+    NEScheduler::get().run_tagged_workloads(workloads, "CpuGemmAssemblyDispatch/pretranspose_B_array");
+}
+} // namespace
+
using namespace arm_compute::experimental;

namespace
@@ -436,7 +476,7 @@

CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());

b->mark_as_unused();
}
@@ -493,9 +533,9 @@
// Check if B is pre-transposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ ldb            = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}

// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -522,7 +562,7 @@
}
else
{
- _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
}
}
}
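As a sanity check on the decomposition itself, here is a toy model of a
windowed pretranspose. pretranspose_part is hypothetical and only transposes
(the real arm_gemm pretranspose_B_array_part also reblocks B for the assembly
kernels); it shows that disjoint [start, end) parts of the window write the
same output as one full pass:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Toy stand-in for a windowed pretranspose: transpose columns [start, end)
    // of a K x N row-major matrix `src` (leading dimension `ldb`) into the
    // N x K row-major buffer `dst`. Hypothetical, for illustration only.
    void pretranspose_part(float *dst, const float *src, int K, int ldb, int start, int end)
    {
        for(int n = start; n < end; ++n)
        {
            for(int k = 0; k < K; ++k)
            {
                dst[n * K + k] = src[k * ldb + n];
            }
        }
    }

    int main()
    {
        const int K = 3, N = 5, ldb = N;
        std::vector<float> src(K * ldb);
        for(std::size_t i = 0; i < src.size(); ++i)
        {
            src[i] = static_cast<float>(i);
        }

        std::vector<float> full(N * K), parts(N * K);
        pretranspose_part(full.data(), src.data(), K, ldb, 0, N);  // one full pass
        pretranspose_part(parts.data(), src.data(), K, ldb, 0, 2); // "thread" 0
        pretranspose_part(parts.data(), src.data(), K, ldb, 2, N); // "thread" 1
        assert(full == parts); // disjoint parts reproduce the single-pass result
        return 0;
    }

Because each part writes a disjoint slice of dst, the per-thread calls in the
patch can run concurrently without locking.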