COMPMID-1405: Create our own gemm_native kernel / function.

Change-Id: Ie0a80bd6b4eb5632cac63ccf54bcb07d4309da19
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140305
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index e796a6a..f4710fa 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -23,21 +23,74 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
 
+#include "arm_compute/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
 
-using namespace arm_compute;
-
+namespace arm_compute
+{
 template <typename TypeInput, typename TypeOutput>
 NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
     : _function(nullptr), _arm_gemm(), _memory_group(std::move(memory_manager))
 {
 }
 
+template <>
+bool NEGEMMAssemblyDispatch<float, float>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    // Only GEMM_NATIVE on AArch64 has an ACL-side function; returning false makes configure() fall back onto arm_gemm.
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+#ifdef __aarch64__
+    if(method == arm_gemm::GemmMethod::GEMM_NATIVE)
+    {
+        // Wrap the native kernel in a single-kernel NESimpleAssemblyFunction and store it as _function.
+        auto native_kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
+        native_kernel->configure(a, b, d, alpha, beta);
+        auto wrapper = support::cpp14::make_unique<NESimpleAssemblyFunction>();
+        wrapper->configure(std::move(native_kernel));
+        _function = std::move(wrapper);
+        return true;
+    }
+    return false;
+#else  /* __aarch64__ */
+    ARM_COMPUTE_UNUSED(method);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(beta);
+    return false;
+#endif /* __aarch64__ */
+}
+
+template <typename TypeInput, typename TypeOutput>
+bool NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+{
+    ARM_COMPUTE_UNUSED(pretranspose_hint);
+    ARM_COMPUTE_UNUSED(beta);
+    ARM_COMPUTE_UNUSED(alpha);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_UNUSED(method);
+    return false; // No ACL-side function for these types: configure() falls back onto arm_gemm.
+}
+
 template <typename TypeInput, typename TypeOutput>
 void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
-    //TODO(antbar01) Check heuristics here to figure out if we should use an ACL IFunction
-    _arm_gemm.configure(a, b, d, alpha, beta, pretranspose_hint, _memory_group);
+    INEGEMMWrapperKernel::Params p           = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+    const CPUInfo               &ci          = NEScheduler::get().cpu_info();
+    unsigned int                 num_threads = NEScheduler::get().num_threads();
+
+    arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+
+    // Ask arm_gemm which method it selects for these arguments, and use an ACL-side function if one exists for it.
+    if(!create_function(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint))
+    {
+        // No ACL-side function for this method: fall back onto the arm_gemm function, reusing the same GemmArgs.
+        _arm_gemm.configure(a, b, d, args, _memory_group);
+    }
 }
 
 template <typename TypeInput, typename TypeOutput>
@@ -75,10 +128,8 @@
 }
 
 #ifndef __aarch64__
-namespace arm_compute
-{
 template <>
-void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
     // arm_gemm::gemm for 8bit only exists for aarch64
     ARM_COMPUTE_UNUSED(a);
@@ -87,11 +138,11 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
     ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
 }
 
 template <>
-void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
 {
     // arm_gemm::gemm for 8bit only exists for aarch64
     ARM_COMPUTE_UNUSED(a);
@@ -100,23 +151,37 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_UNUSED(beta);
     ARM_COMPUTE_UNUSED(pretranspose_hint);
-    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
 }
 
-} //namespace arm_compute
+template <>
+void NEGEMMAssemblyDispatch<uint8_t, uint32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<uint32_t> &args, MemoryGroup &memory_group)
+{
+    // The 8-bit arm_gemm::gemm is AArch64-only, so there is nothing to configure on this architecture.
+    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_UNUSED(args);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
+
+template <>
+void NEGEMMAssemblyDispatch<int8_t, int32_t>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<int32_t> &args, MemoryGroup &memory_group)
+{
+    // The 8-bit arm_gemm::gemm is AArch64-only, so there is nothing to configure on this architecture.
+    ARM_COMPUTE_UNUSED(memory_group);
+    ARM_COMPUTE_UNUSED(args);
+    ARM_COMPUTE_UNUSED(d);
+    ARM_COMPUTE_UNUSED(b);
+    ARM_COMPUTE_UNUSED(a);
+    ARM_COMPUTE_ERROR("Not supported for this architecture");
+}
 #endif // aarch64
 template <typename TypeInput, typename TypeOutput>
-void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, MemoryGroup &memory_group)
+void NEGEMMAssemblyDispatch<TypeInput, TypeOutput>::Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> &args, MemoryGroup &memory_group)
 {
-    const CPUInfo &ci          = NEScheduler::get().cpu_info();
-    const int      M           = d->info()->tensor_shape().y();
-    const int      N           = d->info()->tensor_shape().x();
-    const int      K           = a->info()->tensor_shape().x();
-    const int      batches     = d->info()->tensor_shape().total_size_upper(2);
-    const int      multis      = b->info()->tensor_shape().z();
-    unsigned int   num_threads = NEScheduler::get().num_threads();
-
-    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+    _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput>(args, nullptr);
     if(_gemm_kernel_asm == nullptr)
     {
         //configuration not supported: Leave function unconfigured:
@@ -139,11 +204,10 @@
     //if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
     //the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
     {
-        const unsigned int window_size = _gemm_kernel_asm->get_window_size();
-        if(window_size < num_threads)
+        const int window_size = _gemm_kernel_asm->get_window_size();
+        if(window_size < args._maxthreads)
         {
-            num_threads = window_size;
-            _gemm_kernel_asm->set_nthreads(num_threads);
+            _gemm_kernel_asm->set_nthreads(window_size);
         }
     }
 
@@ -248,8 +312,6 @@
     NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
 }
 
-namespace arm_compute
-{
 template class NEGEMMAssemblyDispatch<float, float>;
 template class NEGEMMAssemblyDispatch<uint8_t, uint32_t>;
 template class NEGEMMAssemblyDispatch<int8_t, int32_t>;
diff --git a/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
new file mode 100644
index 0000000..a4b0dff
--- /dev/null
+++ b/src/runtime/NEON/functions/NESimpleAssemblyFunction.cpp
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/NEON/functions/NESimpleAssemblyFunction.h"
+
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+using namespace arm_compute;
+
+NESimpleAssemblyFunction::NESimpleAssemblyFunction() // NOLINT
+    : _kernel()
+{
+}
+
+void NESimpleAssemblyFunction::run()
+{
+    // Fail fast if run() is called before configure() instead of handing a null kernel to the scheduler.
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_kernel.get());
+    NEScheduler::get().schedule(_kernel.get(), Window::DimX);
+}
+
+void NESimpleAssemblyFunction::configure(std::unique_ptr<INEGEMMWrapperKernel> kernel)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(kernel.get());
+    // run() splits the work along Window::DimX only, so check the window's dimensions before taking
+    // ownership: if the check fails, the previously configured kernel (if any) is left untouched.
+    ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(kernel->window(), 1);
+    _kernel = std::move(kernel);
+}