COMPMID-1405: Create our own gemm_native kernel / function.

Change-Id: Ie0a80bd6b4eb5632cac63ccf54bcb07d4309da19
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140305
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index e796a6a..f4710fa 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -23,21 +23,74 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
 
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 template <typename TypeInput, typename TypeOutput>
 NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
     : _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager))
 {
 }
 
+template <> // float/float specialization: may replace arm_gemm with an ACL-native GEMM function
+bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    ARM_COMPUTE_UNUSED(method);
+    ARM_COMPUTE_UNUSED(a); // a, b, d, alpha and beta are only read when __aarch64__ is defined
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(pretranspose_hint); // not forwarded to the ACL-native kernel
+    switch(method)
+    {
+#ifdef __aarch64__
+        case arm_gemm::GemmMethod::GEMM_NATIVE: // the only method handled natively by ACL here
+        {
+            auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
+            kernel->configure(a, b, d, alpha, beta);
+            auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
+            function->configure(std::move(kernel));
+            _function = std::move(function);
+            return true; // configure() will skip the arm_gemm fallback
+        }
+#endif /* __aarch64__ */
+        default:
+            return false; // no ACL-native function: configure() falls back to arm_gemm
+    }
+}
+
+template <typename TypeInput, typename TypeOutput> // generic version: no ACL-native function for these types
+bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    ARM_COMPUTE_UNUSED(method);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    return false; // always defer to the arm_gemm fallback in configure()
+}
+
 template <typename TypeInput, typename TypeOutput>
 void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
-    //TODO(antbar01) Check heuristics here to figure out if we should use an ACL IFunction
-    _arm_gemm.configure(a, b, d, alpha, beta, pretranspose_hint, _memory_group);
+    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+    const CPUInfo               &ci          = NEScheduler::get().cpu_info();
+    unsigned int                 num_threads = NEScheduler::get().num_threads();
+
+    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); // shared by method selection and the arm_gemm fallback
+
+    // Try to create an ACL-native function for the method arm_gemm selects for these arguments:
+    if(!create_function(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint))
+    {
+        // Fall back to the arm_gemm function when ACL has no native implementation of this method.
+        _arm_gemm.configure(a, b, d, args, _memory_group);
+    }
 }
 
 template <typename TypeInput, typename TypeOutput>
@@ -75,10 +128,8 @@
 }
 
 #ifndef __aarch64__
-namespace arm_compute
-{
 template <>
-void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
     // arm_gemm::gemm for 8bit only exists for aarch64
     ARM_COMPUTE_UNUSED(a);
@@ -87,11 +138,11 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
     ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
 }
 
 template <>
-void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
     // arm_gemm::gemm for 8bit only exists for aarch64
     ARM_COMPUTE_UNUSED(a);
@@ -100,23 +151,37 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
     ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
 }
 
-} //namespace arm_compute
+template <>
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<uint32_t> &args, MemoryGroup &memory_group)
+{
+    // arm_gemm::gemm for 8-bit types only exists for aarch64, so reaching this on other architectures is an error
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(args);
+    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
+
+template <>
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<int32_t> &args, MemoryGroup &memory_group)
+{
+    // arm_gemm::gemm for 8-bit types only exists for aarch64, so reaching this on other architectures is an error
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(args);
+    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
 #endif // aarch64
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
 {
-    const CPUInfo &ci          = NEScheduler::get().cpu_info();
-    const int      M           = d->info()->tensor_shape().y();
-    const int      N           = d->info()->tensor_shape().x();
-    const int      K           = a->info()->tensor_shape().x();
-    const int      batches     = d->info()->tensor_shape().total_size_upper(2);
-    const int      multis      = b->info()->tensor_shape().z();
-    unsigned int   num_threads = NEScheduler::get().num_threads();
-
-    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
     if(_gemm_kernel_asm == nullptr)
     {
         //configuration not supported: Leave function unconfigured:
@@ -139,11 +204,10 @@
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
     {
-        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
-        if(window_size < num_threads)
+        const int window_size = _gemm_kernel_asm->get_window_size();
+        if(window_size < args._maxthreads)
         {
-            num_threads = window_size;
-            _gemm_kernel_asm->set_nthreads(num_threads);
+            _gemm_kernel_asm->set_nthreads(window_size);
         }
     }
 
@@ -248,8 +312,6 @@
     NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
 }
 
-namespace arm_compute
-{
 template class NEGEMMAssemblyDispatch<float, float>;
 template class NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
 template class NEGEMMAssemblyDispatch<int8_t, int32_t>;