COMPMID-1406: Refactor gemm_interleaved to use our own types and scheduler
- Ported PrepareB kernel from gemm_interleaved
- Ported TransformA feature from gemm_interleaved
- Allocate reshaped a and b buffers
- Added memory_manager / memory_group
- MatrixMultiply kernel
- Interleave kernels execution.
- Fixed a few bugs: all nightly Convolution tests now pass for threads=1
and threads=4
- Added Doxygen documentation and comments in the code
- Added support for all supported data types
Change-Id: Iffa1c09fda0bb9c61213bb83524d5a48e7ecb03c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141281
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index f17da7d..8ba620f 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,9 +24,13 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
#include <arm_neon.h>
@@ -34,31 +38,96 @@
{
namespace
{
-template <typename TypeInput, typename TypeOutput>
-std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager) // Type-agnostic factory: returns a configured NEGEMMInterleavedWrapper for GEMM_INTERLEAVED, nullptr for anything else
+
{
- ARM_COMPUTE_UNUSED(method);
- ARM_COMPUTE_UNUSED(a);
- ARM_COMPUTE_UNUSED(b);
- ARM_COMPUTE_UNUSED(d);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_UNUSED(pretranspose_hint);
- return nullptr;
-}
-template <>
-std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
- ARM_COMPUTE_UNUSED(method);
- ARM_COMPUTE_UNUSED(a);
- ARM_COMPUTE_UNUSED(b);
- ARM_COMPUTE_UNUSED(d);
- ARM_COMPUTE_UNUSED(alpha);
- ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_UNUSED(pretranspose_hint);
+ //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
switch(method)
{
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
+ {
+ if(!pretranspose_hint) // No ACL function is created here unless the pretranspose hint is set
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); // The wrapper receives the memory manager at construction
+ function->configure(a, b, d, alpha, beta, pretranspose_hint);
+ return std::move(function); // Converts unique_ptr<NEGEMMInterleavedWrapper> into the unique_ptr<IFunction> return type
+ }
+ default: // Every other method is left unhandled by this factory
+ return nullptr;
+ }
+}
+
+template <typename TypeInput, typename TypeOutput>
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager) // Primary template: used for type pairs that have no explicit specialisation
+{
+ ARM_COMPUTE_UNUSED(method); // All arguments are deliberately ignored by this generic version
+ ARM_COMPUTE_UNUSED(a);
+ ARM_COMPUTE_UNUSED(b);
+ ARM_COMPUTE_UNUSED(d);
+ ARM_COMPUTE_UNUSED(alpha);
+ ARM_COMPUTE_UNUSED(beta);
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
+ return nullptr; // Always reports "no ACL function available"
+}
+
#ifdef __aarch64__
+template <>
+std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager) // int8_t input / int32_t output: only GEMM_INTERLEAVED_DOT yields a function
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint) // No ACL function is created here unless the pretranspose hint is set
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); // The wrapper receives the memory manager at construction
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function); // Converts unique_ptr<NEGEMMInterleavedWrapper> into the unique_ptr<IFunction> return type
+ }
+ default: // Every other method is left unhandled by this specialisation
+ return nullptr;
+ }
+ return nullptr; // Not reached: every path through the switch above returns
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager) // uint8_t input / uint32_t output: only GEMM_INTERLEAVED_DOT yields a function
+{
+ switch(method)
+ {
+ case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+ {
+ if(!pretranspose_hint) // No ACL function is created here unless the pretranspose hint is set
+ {
+ return nullptr;
+ }
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager); // The wrapper receives the memory manager at construction
+ function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+ return std::move(function); // Converts unique_ptr<NEGEMMInterleavedWrapper> into the unique_ptr<IFunction> return type
+ }
+ default: // Every other method is left unhandled by this specialisation
+ return nullptr;
+ }
+ return nullptr; // Not reached: every path through the switch above returns
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ std::shared_ptr<IMemoryManager> memory_manager)
+{
+ ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(memory_manager);
+ switch(method)
+ {
case arm_gemm::GemmMethod::GEMM_NATIVE:
{
auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
@@ -67,11 +136,11 @@
function->configure(std::move(kernel));
return std::move(function);
}
-#endif /* __aarch64__ */
default:
return nullptr;
}
}
+#endif /* __aarch64__ */
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput>
@@ -173,11 +242,11 @@
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
+ ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
_gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
_b->mark_as_unused();
}
@@ -260,7 +329,7 @@
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
- ITensor *d, float alpha, float beta, bool pretranspose_hint)
+ ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -269,7 +338,13 @@
arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
//Try to create an ACL function:
- acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
+ acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+ // If the type agnostic factory failed to create an ACL function, try the specialised one:
+ if(acl_function == nullptr)
+ {
+ acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+ }
+ //If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
@@ -282,7 +357,7 @@
} //namespace
NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
+ : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
{
}
@@ -321,20 +396,20 @@
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default: