COMPMID-1867: Add NEON/SVE GEMM Hybrid kernels.

Change-Id: Ib40a9921e7f9a6a8be6c38872d6b3a0f24ed0cd3
Reviewed-on: https://review.mlplatform.org/515
Reviewed-by: Anthony Barbier <Anthony.barbier@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
index 6734e3c..bf80784 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -22,56 +22,53 @@
  * SOFTWARE.
  */
 
-#include "gemv_batched.hpp"
+#include <arm_gemm.hpp>
+
+#include <functional>
 
 namespace arm_gemm {
 
 template<typename Top, typename Tret>
-class GemmImplementation {
-public:
-    /* Is this implementation compatible with the args as provided? */
-    virtual bool is_supported(const GemmArgs<Tret> &args)   { return true; }
-    /* Is this implementation "recommended" for these args (heuristic)? */
-    virtual bool is_recommended(const GemmArgs<Tret> &args) { return true; }
-    /* Instantiate this method please. */
-    virtual UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) = 0;
-
-    /* Indicate the "GemmMethod" for use as a selector */
-    const GemmMethod method;
-
-    virtual ~GemmImplementation() { }
-
-    GemmImplementation(GemmMethod method) : method(method) { }
-};
-
-/* "gemv_batched" implementation is type-agnostic, so template it here. */
-template<typename Top, typename Tret>
-class GemmImpl_gemv_batched : public GemmImplementation<Top, Tret> {
-public:
-    bool is_supported(const GemmArgs<Tret> &args) override {
-        return (args._Msize==1 && args._nbatches > 1);
-    }
-
-    UniqueGemmCommon<Top, Tret> instantiate(const GemmArgs<Tret> &args) override {
-        return UniqueGemmCommon<Top, Tret> (new GemvBatched<Top, Tret>(args));
-    }
-
-    GemmImpl_gemv_batched() : GemmImplementation<Top, Tret>(GemmMethod::GEMV_BATCHED) { }
+struct GemmImplementation { /* Descriptor for one entry of the list returned by gemm_implementation_list(). */
+    const GemmMethod                                               method; /* Method selector; GemmMethod::DEFAULT marks the end of the list. */
+    const char *                                                   name; /* Kernel name: reported to callers and substring-matched against the config filter. */
+    std::function<bool(const GemmArgs<Tret> &)>                    is_supported; /* Can this entry handle the args?  Unset (null) is treated as "always supported". */
+    std::function<bool(const GemmArgs<Tret> &)>                    is_recommended; /* Does this entry recommend itself for the args?  Unset (null) is treated as "yes". */
+    std::function<GemmCommon<Top, Tret> *(const GemmArgs<Tret> &)> instantiate; /* Creates the GEMM object for the args; gemm() wraps the returned raw pointer in a UniqueGemmCommon. */
 };
 
 /* "Master" function implemented for each valid combination of types.
  * Returns a list of GEMM implementation descriptors for processing by the
- * other functions.  */
+ * other functions, terminated by an implementation with
+ * method==GemmMethod::DEFAULT.  */
 template<typename Top, typename Tret>
-std::vector<GemmImplementation<Top, Tret> *> &gemm_implementation_list();
+const GemmImplementation<Top, Tret> *gemm_implementation_list();
 
+/*
+ * Select a GEMM implementation for the given arguments.
+ *
+ * The logic here returns the first method on the list which supports the
+ * requested problem parameters, matches the provided filters (method and/or
+ * name string match) and recommends itself.
+ *
+ * If there is no such method, it will return the first method which
+ * supports the requested parameters and passes the filters, regardless of
+ * recommendation.
+ *
+ * If no method supports the requested parameters and passes the filters,
+ * this function returns false and doesn't touch the provided pointer
+ * reference.
+ */
 template<typename Top, typename Tret>
-GemmImplementation<Top, Tret> *find_implementation(GemmArgs<Tret> &args, GemmConfig *cfg) {
+bool find_implementation(const GemmArgs<Tret> &args, const GemmImplementation<Top, Tret> * &impl) {
     auto gemms = gemm_implementation_list<Top, Tret>();
+    const GemmConfig *cfg = args._cfg;
 
-    for(auto &&i : gemms) {
+    const GemmImplementation<Top, Tret> *saved_impl = nullptr;
+
+    for (auto i = gemms; i->method != GemmMethod::DEFAULT; i++) {
         /* Skip if this implementation doesn't support these args. */
-        if (!i->is_supported(args)) {
+        if (i->is_supported != nullptr && !i->is_supported(args)) {
             continue;
         }
 
@@ -80,52 +77,92 @@
             continue;
         }
 
-        /* If no specific method is requested, check that this method recommends itself. */
-        if ((!cfg || cfg->method == GemmMethod::DEFAULT) && !i->is_recommended(args)) {
+        /* Skip if a filter is to be applied and it doesn't match. */
+        if (cfg && cfg->filter != "" && !strstr(i->name, cfg->filter.c_str())) {
             continue;
         }
 
-        return i;
+        /* At this point, if we don't have a saved implementation, save this
+         * one.  This is so that we always return something if a filter
+         * matches, even if it doesn't recommend itself.
+         */
+        if (saved_impl == nullptr) {
+            saved_impl=i;
+        }
+
+        /* Check that this method recommends itself. */
+        if (i->is_recommended != nullptr && !i->is_recommended(args)) {
+            continue;
+        }
+
+        impl=i;
+
+        return true;
     }
 
-    return nullptr;
-}
-
-template<typename Top, typename Tret>
-UniqueGemmCommon<Top, Tret> gemm(GemmArgs<Tret> &args, GemmConfig *cfg) {
-    auto impl = find_implementation<Top, Tret>(args, cfg);
-
-    if (impl) {
-        return impl->instantiate(args);
-    }
-
-    return UniqueGemmCommon<Top, Tret>(nullptr);
-}
-
-template<typename Top, typename Tret>
-GemmMethod get_gemm_method(GemmArgs<Tret> &args) {
-    auto impl = find_implementation<Top, Tret>(args, nullptr);
-
-    if (impl) {
-        return impl->method;
-    }
-
-    /* This shouldn't happen - there should always be at least one valid implementation. */
-    return GemmMethod::DEFAULT;
-}
-
-template<typename Top, typename Tret>
-bool method_is_compatible(GemmMethod method, GemmArgs<Tret> &args) {
-    /* Determine if the method is valid by attempting to obtain an implementation specifying this method. */
-    GemmConfig cfg(method);
-
-    auto impl = find_implementation<Top, Tret>(args, &cfg);
-
-    if (impl) {
+    /* We didn't find an option matching the filters that recommended
+     * itself.  But if we found something earlier that matched the filters
+     * but wasn't recommended, return it here.  */
+    if (saved_impl != nullptr) {
+        impl = saved_impl;
         return true;
     }
 
     return false;
 }
 
-} // namespace arm_gemm
+template<typename Top, typename Tret> /* List the names of every implementation able to handle 'args'. */
+std::vector<std::string> get_compatible_kernels(const GemmArgs<Tret> &args) {
+    std::vector<std::string> names;
+
+    /* The list ends at the first entry whose method is GemmMethod::DEFAULT. */
+    const GemmImplementation<Top, Tret> *entry = gemm_implementation_list<Top, Tret>();
+
+    while (entry->method != GemmMethod::DEFAULT) {
+        /* An entry without an is_supported callback accepts every problem. */
+        const bool usable = (entry->is_supported == nullptr) || entry->is_supported(args);
+        if (usable) {
+            names.push_back(entry->name);
+        }
+        ++entry;
+    }
+    return names;
+}
+
+template<typename Top, typename Tret> /* Build the GEMM object chosen by find_implementation() for 'args'; holds nullptr if nothing matches. */
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs<Tret> &args) {
+    const GemmImplementation<Top, Tret> *impl = nullptr; /* Out-parameter: find_implementation() writes it only on success. */
+
+    if (find_implementation<Top, Tret>(args, impl)) {
+        return UniqueGemmCommon<Top, Tret>(impl->instantiate(args)); /* Wrap the raw pointer returned by instantiate(). */
+    }
+
+    return UniqueGemmCommon<Top, Tret>(nullptr); /* No supported implementation passed the filters. */
+}
+
+template<typename Top, typename Tret> /* Report the method and name of the implementation find_implementation() selects for 'args'. */
+KernelDescription get_gemm_method(const GemmArgs<Tret> &args) {
+    const GemmImplementation<Top, Tret> *impl = nullptr; /* Out-parameter: find_implementation() writes it only on success. */
+
+    if (find_implementation<Top, Tret>(args, impl)) {
+        return KernelDescription(impl->method, impl->name);
+    }
+
+    /* This shouldn't happen - there should always be at least one valid implementation. */
+    return KernelDescription(); /* Default-constructed description, returned when nothing matched. */
+}
+
+template<typename Top, typename Tret> /* Ask whether an implementation can be found when 'method' is requested for 'args'. */
+bool method_is_compatible(GemmMethod method, const GemmArgs<Tret> &args) {
+    /* Search with a private copy of the arguments whose config is built from 'method' alone. */
+    GemmConfig     method_cfg(method);
+    GemmArgs<Tret> probe_args = args;
+
+    probe_args._cfg = &method_cfg;
+
+    /* find_implementation() needs somewhere to store its choice; only the verdict matters here. */
+    const GemmImplementation<Top, Tret> *ignored;
+
+    return find_implementation<Top, Tret>(probe_args, ignored);
+}
+
+} // namespace arm_gemm
\ No newline at end of file