COMPMID-1271: New system for GEMM heuristics

This patch implements a system for separating the "validity" aspect
from the "preferred" aspect of the current heuristics in gemm_*.cpp.

Now, each gemm_*.cpp defines a list of candidate implementations,
each of which supplies an is_valid() function (to check for
validity), an is_preferred() function (the "heuristic" part), and an
instantiate() function which actually produces the GemmCommon object
pointer.

The actual gemm() function is now templated and uses this list to
select an implementation.  This patch also implements a mechanism to
identify the preferred implementation, and override it via the
GemmConfig structure.

Change-Id: Id49ab7af8bf2e3e9fd951a9698883ade234d40e1
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139120
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 65f43f3..829ae32 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -28,6 +28,7 @@
 #include "arm_gemm.hpp"
 
 #include "gemm_common.hpp"
+#include "gemm_implementation.hpp"
 #include "gemm_interleaved.hpp"
 
 #include "kernels/a64_hgemm_24x8.hpp"
@@ -36,37 +37,59 @@
 
 namespace arm_gemm {
 
-template<>
-UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
-                                      const unsigned int nbatches, const unsigned int nmulti,
-                                      const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta,
-                                      const int maxthreads, const bool pretransposed_hint) {
 #ifdef __aarch64__
 
-// Only consider the native FP16 kernel if it will get built.
 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    // If the compiler is configured to enable this feature always, then assume it is available at runtime too.
-    const bool use_fp16=true;
-#else
-    // Otherwise, detect at runtime via CPUInfo.
-    const bool use_fp16=ci.has_fp16();
-#endif
-
-    // If FP16 is supported, use it.
-    if (use_fp16) {
-        return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+class GemmImpl_gemm_fp16_interleaved_fp16 : public GemmImplementation<__fp16, __fp16> { // Candidate that uses the native FP16 hgemm_24x8 kernel.
+public:
+#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    bool is_supported(const GemmArgs<__fp16> &args) override { // Only compiled when the FP16 feature macro is not set at build time; defers to the CPUInfo runtime check.
+        return args._ci->has_fp16();
     }
 #endif
 
-    // Fallback to using the blocked SGEMM kernel.
-    return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
-#else
-    // For AArch32, only support the SGEMM route for now.
-    return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+    UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override { // Builds a GemmInterleaved<hgemm_24x8, ...> object from args and returns it wrapped in UniqueGemmCommon.
+        return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(args));
+    }
+
+    GemmImpl_gemm_fp16_interleaved_fp16() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED_FP16) { } // Passes GemmMethod::GEMM_INTERLEAVED_FP16 to the base-class constructor.
+};
 #endif
+
+#endif // __aarch64__
+
+class GemmImpl_gemm_fp16_interleaved : public GemmImplementation<__fp16, __fp16> { // Candidate that runs __fp16 GEMM through the interleaved sgemm_* kernels; it does not override is_supported().
+public:
+    UniqueGemmCommon<__fp16, __fp16> instantiate(const GemmArgs<__fp16> &args) override { // Builds the GemmInterleaved object for the architecture being compiled for.
+#ifdef __aarch64__
+        return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(args)); // AArch64 build: sgemm_12x8 kernel.
+#elif defined(__arm__)
+        return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args)); // AArch32 build: sgemm_8x6 kernel.
+#else
+# error Unknown Architecture
+#endif
+    }
+
+    GemmImpl_gemm_fp16_interleaved() : GemmImplementation<__fp16, __fp16>(GemmMethod::GEMM_INTERLEAVED) { } // Passes GemmMethod::GEMM_INTERLEAVED to the base-class constructor.
+};
+
+static std::vector<GemmImplementation<__fp16, __fp16> *> gemm_fp16_methods = { // Candidate list handed out by gemm_implementation_list<__fp16, __fp16>().
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
+    new GemmImpl_gemm_fp16_interleaved_fp16(), // Guard mirrors the one around the class definition, so the native FP16 entry is listed whenever the class is built.
+#endif
+    new GemmImpl_gemm_fp16_interleaved()
+};
+
+template<>
+std::vector<GemmImplementation<__fp16, __fp16> *> &gemm_implementation_list<__fp16, __fp16>() { // Specialization for <__fp16, __fp16>: returns the file-local gemm_fp16_methods list by reference.
+    return gemm_fp16_methods;
 }
 
+/* Explicitly instantiate the external functions for these types, so that the templated gemm(), get_gemm_method() and method_is_compatible() are emitted here for <__fp16, __fp16>. */
+template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16>(GemmArgs<__fp16> &args, GemmConfig *cfg);
+template GemmMethod get_gemm_method<__fp16, __fp16>(GemmArgs<__fp16> &args);
+template bool method_is_compatible<__fp16, __fp16>(GemmMethod method, GemmArgs<__fp16> &args);
+
 } // namespace arm_gemm
 
 #endif // __ARM_FP16_ARGS