COMPMID-1406: Refactor gemm_interleaved to use our own types and scheduler

- Ported PrepareB kernel from gemm_interleaved
- Ported TransformA feature from gemm_interleaved
- Allocated reshaped A and B buffers
- Added memory_manager / memory_group
- Added MatrixMultiply kernel
- Interleaved the execution of the kernels
- Fixed a few bugs: all nightly Convolution tests now pass for threads=1
  and threads=4
- Added Doxygen documentation and code comments
- Added support for all supported data types

Change-Id: Iffa1c09fda0bb9c61213bb83524d5a48e7ecb03c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141281
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index f17da7d..8ba620f 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -24,9 +24,13 @@
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
 
 #include "arm_compute/core/CPP/Validate.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h"
 #include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+#include "arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h"
 
 #include <arm_neon.h>
 
@@ -34,31 +38,96 @@
 {
 namespace
 {
-template <typename TypeInput, typename TypeOutput>
-std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+/** Type-agnostic factory (called for every data type): for GEMM_INTERLEAVED with @p pretranspose_hint set it returns an NEGEMMInterleavedWrapper configured with a, b, d, alpha, beta and owning @p memory_manager; for any other method, or without the hint, it returns nullptr. */
+std::unique_ptr<IFunction> create_function_all_types(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                     std::shared_ptr<IMemoryManager> memory_manager)
 {
-    ARM_COMPUTE_UNUSED(method);
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
-    return nullptr;
-}
-template <>
-std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
-{
-    ARM_COMPUTE_UNUSED(method);
-    ARM_COMPUTE_UNUSED(a);
-    ARM_COMPUTE_UNUSED(b);
-    ARM_COMPUTE_UNUSED(d);
-    ARM_COMPUTE_UNUSED(alpha);
-    ARM_COMPUTE_UNUSED(beta);
-    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    // Note: FP16 support does not need to be checked here because NEGEMMAssemblyDispatch::configure() has already checked it.
     switch(method)
     {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
+        {
+            if(!pretranspose_hint) // Without the pretranspose hint this path is not used: nullptr lets the caller try the type-specific factory next
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint);
+            return std::move(function); // explicit move: the return type is unique_ptr<IFunction>, so pre-C++14 compilers would not move implicitly
+        }
+        default:
+            return nullptr;
+    }
+}
+
+/** Generic fallback used for every (TypeInput, TypeOutput) pair without an explicit specialisation.
+ *
+ * It never builds an ACL function: returning nullptr makes create_function_or_arm_gemm()
+ * fall back onto the arm_gemm implementation.
+ *
+ * @return Always nullptr.
+ */
+template <typename TypeInput, typename TypeOutput>
+std::unique_ptr<IFunction> create_function(arm_gemm::GemmMethod /*method*/, const ITensor * /*a*/, const ITensor * /*b*/, ITensor * /*d*/, float /*alpha*/, float /*beta*/,
+                                           bool /*pretranspose_hint*/, std::shared_ptr<IMemoryManager> /*memory_manager*/)
+{
+    // Parameters are intentionally unnamed: none of them is needed to report "not supported".
+    return nullptr;
+}
+
 #ifdef __aarch64__
+/** int8_t -> int32_t factory: for GEMM_INTERLEAVED_DOT with @p pretranspose_hint set, builds an NEGEMMInterleavedWrapper with use_dot enabled; returns nullptr otherwise. */
+template <>
+std::unique_ptr<IFunction> create_function<int8_t, int32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                            std::shared_ptr<IMemoryManager> memory_manager)
+{
+    switch(method)
+    {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+        {
+            if(!pretranspose_hint)
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+            return std::move(function);
+        }
+        default:
+            return nullptr;
+    }
+}
+
+/** uint8_t -> uint32_t factory: for GEMM_INTERLEAVED_DOT with @p pretranspose_hint set, builds an NEGEMMInterleavedWrapper with use_dot enabled; returns nullptr otherwise. */
+template <>
+std::unique_ptr<IFunction> create_function<uint8_t, uint32_t>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                              std::shared_ptr<IMemoryManager> memory_manager)
+{
+    switch(method)
+    {
+        case arm_gemm::GemmMethod::GEMM_INTERLEAVED_DOT:
+        {
+            if(!pretranspose_hint)
+            {
+                return nullptr;
+            }
+            auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+            function->configure(a, b, d, alpha, beta, pretranspose_hint, true /* use_dot */);
+            return std::move(function);
+        }
+        default:
+            return nullptr;
+    }
+}
+
+template <>
+std::unique_ptr<IFunction> create_function<float, float>(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+                                                         std::shared_ptr<IMemoryManager> memory_manager)
+{
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(memory_manager);
+    switch(method)
+    {
         case arm_gemm::GemmMethod::GEMM_NATIVE:
         {
             auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
@@ -67,11 +136,11 @@
             function->configure(std::move(kernel));
             return std::move(function);
         }
-#endif /* __aarch64__ */
         default:
             return nullptr;
     }
 }
+#endif /* __aarch64__ */
 
 /** Fallback in case ACL doesn't have a function */
 template <typename TypeInput, typename TypeOutput>
@@ -173,11 +242,11 @@
         // Pretranspose B if required
         if(_gemm_kernel_asm->B_pretranspose_required())
         {
+            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             const int  ldb            = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
             const auto in1_ptr        = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
             const int  multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
 
-            ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
             _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
             _b->mark_as_unused();
         }
@@ -260,7 +329,7 @@
 
 template <typename TypeInput, typename TypeOutput>
 void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
-                                 ITensor *d, float alpha, float beta, bool pretranspose_hint)
+                                 ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
 {
     INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
     const CPUInfo               &ci          = NEScheduler::get().cpu_info();
@@ -269,7 +338,13 @@
     arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
 
     //Try to create an ACL function:
-    acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint);
+    acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+    // If the type agnostic factory failed to create an ACL function, try the specialised one:
+    if(acl_function == nullptr)
+    {
+        acl_function = create_function<TypeInput, TypeOutput>(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, memory_manager);
+    }
+    //If we still don't have an ACL function:
     if(acl_function == nullptr)
     {
         //Fallback onto arm_gemm function if ACL doesn't support this method.
@@ -282,7 +357,7 @@
 } //namespace
 
 NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
-    : _function(nullptr), _arm_gemm(nullptr), _memory_group(std::move(memory_manager))
+    : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager) // keep a copy so configure() can forward it to create_function_or_arm_gemm(); NOTE(review): shared_ptr copied twice, moving into the last-declared member would save one copy (confirm declaration order in the header first)
 {
 }
 
@@ -321,20 +396,20 @@
     switch(a->info()->data_type())
     {
         case DataType::F32:
-            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #ifdef __aarch64__
         case DataType::U8:
         case DataType::QASYMM8:
-            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
         case DataType::S8:
-            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #endif /* __aarch64__ */
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case DataType::F16:
-            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint);
+            create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
             break;
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
         default: