COMPMID-1271: New system for GEMM heuristics

This patch implements a system for separating the "validity" aspect
from the "preferred" aspect of the current heuristics in gemm_*.cpp.

Now, each gemm_*.cpp defines a list of candidate implementations,
each of which supplies an is_supported() function (to check for
validity), an is_recommended() function (the "heuristic" part), and
an instantiate() function which actually produces the GemmCommon
object pointer.

The actual gemm() function is now templated and uses this list to
select an implementation.  This patch also implements a mechanism to
identify the preferred implementation, and override it via the
GemmConfig structure.

Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 2fd040e..6e47adb 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -23,6 +23,7 @@
  */
 #include "arm_gemm.hpp"
 #include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
 #include "gemm_interleaved.hpp"
 #include "gemm_native.hpp"
 #include "gemv_batched.hpp"
@@ -37,47 +38,92 @@
 
 namespace arm_gemm {
 
-template<>
-UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
-                                                  const unsigned int nbatches, const unsigned int nmulti,
-                                                  const bool trA, const bool trB, const float alpha, const float beta,
-                                                  const int maxthreads, const bool pretransposed_hint) {
-    /* Handle "batched GEMV" */
-    if (M==1 && nbatches>1) {
-        return UniqueGemmCommon<float, float> (new GemvBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
-    }
-
 #ifdef __aarch64__
-    /* Cases in priority order */
-    /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set.  nbatches must be 1 or we would have returned above so don't test. */
-    if (M==1 && alpha==1.0f && pretransposed_hint) {
-        return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
+// SGEMM implementations for AArch64
+
+// Pretransposed GEMV
+class GemmImpl_sgemm_gemv_pretransposed : public GemmImplementation<float, float> {
+public:
+    bool is_supported(const GemmArgs<float> &args) override {
+        return (args._Msize==1 && args._alpha==1.0f && args._pretransposed_hint && args._nbatches==1);
     }
 
-    /* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
-    if (M==1 && alpha==1.0f && !trA && !trB) {
-        return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
+    UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+        return UniqueGemmCommon<float, float> (new GemvPretransposed<sgemv_pretransposed, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._trB, args._beta));
     }
 
-    /* Native GEMM: requires K at least 4, N a multiple of 16, doesn't
-     * handle alpha or transpose.  Use for small N/K, or if the blocked GEMM
-     * won't thread properly.  */
-    if ((K >= 4) && ((N % 16) == 0) && alpha==1.0f && !trA && !trB &&
-        ((K <= 128 && N <= 128) || (nmulti > 1 && (M/maxthreads) < 8))) {
-        return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
+    GemmImpl_sgemm_gemv_pretransposed() : GemmImplementation<float, float>(GemmMethod::GEMV_PRETRANSPOSED) { }
+};
+
+// Native GEMV
+class GemmImpl_sgemm_gemv_native_transposed : public GemmImplementation<float, float> {
+public:
+    bool is_supported(const GemmArgs<float> &args) override {
+        return (args._Msize==1 && args._alpha==1.0f && !args._trA && !args._trB && args._nbatches==1);
     }
 
-    /* Blocked GEMM, handles all cases. */
-    return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+    UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+        return UniqueGemmCommon<float, float> (new GemvNativeTransposed<sgemv_trans, float, float>(args._ci, args._Nsize, args._Ksize, args._nmulti, args._beta));
+    }
+
+    GemmImpl_sgemm_gemv_native_transposed() : GemmImplementation<float, float>(GemmMethod::GEMV_NATIVE_TRANSPOSED) { }
+};
+
+// Native GEMM
+class GemmImpl_sgemm_gemm_native : public GemmImplementation<float, float> {
+public:
+    bool is_supported(const GemmArgs<float> &args) override {
+        return (args._Ksize>=4 && (args._Nsize % 16)==0 && args._alpha==1.0f && !args._trA && !args._trB);
+    }
+
+    bool is_recommended(const GemmArgs<float> &args) override {
+        return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8));
+    }
+
+    UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+        return UniqueGemmCommon<float, float> (new GemmNative<sgemm_native_16x4, float, float>(args._ci, args._Msize, args._Nsize, args._Ksize, args._nbatches, args._nmulti, args._beta));
+    }
+
+    GemmImpl_sgemm_gemm_native() : GemmImplementation<float, float>(GemmMethod::GEMM_NATIVE) { }
+};
+#endif // __aarch64__
+
+// Interleaved GEMM
+class GemmImpl_sgemm_gemm_interleaved : public GemmImplementation<float, float> {
+public:
+    UniqueGemmCommon<float, float> instantiate(const GemmArgs<float> &args) override {
+#ifdef __aarch64__
+        return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_12x8, float, float>(args));
+#elif defined(__arm__)
+        return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(args));
 #else
-    return UniqueGemmCommon<float, float> (new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+# error Unknown Architecture.
 #endif
+    }
+
+    GemmImpl_sgemm_gemm_interleaved() : GemmImplementation<float, float>(GemmMethod::GEMM_INTERLEAVED) { }
+};
+
+/* List of implementations (order matters) */
+static std::vector<GemmImplementation<float, float> *> SGemmMethods = {
+    new GemmImpl_gemv_batched<float, float>(),
+#ifdef __aarch64__
+    new GemmImpl_sgemm_gemv_pretransposed(),
+    new GemmImpl_sgemm_gemv_native_transposed(),
+    new GemmImpl_sgemm_gemm_native(),
+#endif
+    new GemmImpl_sgemm_gemm_interleaved()
+};
+
+/* Templated function to return this list. */
+template<>
+std::vector<GemmImplementation<float, float> *> &gemm_implementation_list<float, float>() {
+    return SGemmMethods;
 }
 
-// Instantiate static class variables.
-#ifdef __aarch64__
-const int sgemm_native_16x4::out_width;
-const int sgemm_native_16x4::out_height;
-#endif
+/* Explicitly instantiate the external functions for these types. */
+template UniqueGemmCommon<float, float> gemm<float, float>(GemmArgs<float> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<float, float>(GemmArgs<float> &args);
+template bool method_is_compatible<float, float>(GemmMethod method, GemmArgs<float> &args);
 
 } // namespace arm_gemm