Update GEMM assembly kernels

- Introduce FP32 kernels that perform their internal computation in
  BFloat16 when fast_mode is enabled
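- Improve kernel selection heuristics: replace per-CPU-model checks with
  cycle-estimate-based kernel selection
- Add MMLA-based hybrid kernels for uint8 GEMM (A64 and SVE)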

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I68a9e7e862b6fd2721b46e0d7cc791091c4ab279
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5965
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index f3f2f33..abd2799 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -29,13 +29,17 @@
 #include "kernels/a64_gemm_u8_4x4.hpp"
 #include "kernels/a64_gemm_u8_8x12.hpp"
 #include "kernels/a64_hybrid_u8qa_dot_4x16.hpp"
+#include "kernels/a64_hybrid_u8qa_mmla_4x16.hpp"
 #include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
+#include "kernels/a64_hybrid_u8u32_mmla_6x16.hpp"
 #include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
 #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
 #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
 
-#include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
 #include "kernels/sve_hybrid_u8qa_dot_4x4VL.hpp"
+#include "kernels/sve_hybrid_u8qa_mmla_4x4VL.hpp"
+#include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
+#include "kernels/sve_hybrid_u8u32_mmla_6x4VL.hpp"
 #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
 #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
 #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
@@ -51,55 +55,77 @@
 static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
 {
 #ifdef ARM_COMPUTE_ENABLE_SVE
-#ifdef ARM_COMPUTE_ENABLE_I8MM
-{
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
+    GemmMethod::GEMM_HYBRID,
+    "sve_hybrid_u8qa_mmla_4x4VL",
+    [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp) && args._ci->has_sve2() && args._ci->has_svei8mm(); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "sve_interleaved_u8u32_mmla_8x3VL",
     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_svei8mm() && (args._Ksize>8); },
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>(args, qp); }
-},
-#endif
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
+    GemmMethod::GEMM_INTERLEAVED,
+    "sve_hybrid_u8u32_mmla_6x4VL",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_svei8mm(); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
+),
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
     "sve_smallK_hybrid_u8u32_dot_8x1VL",
     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve() && args._Ksize<=64 && !args._indirect_input; },
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    nullptr,
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint8_t>(args, qp); }
 },
-#ifdef ARM_COMPUTE_ENABLE_SVE2 // Requantizing kernels include some SVE2 only instructions (SQRDMULH, SRSHL)
-{
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_HYBRID,
     "sve_hybrid_u8qa_dot_4x4VL",
-    [](const GemmArgs &args, const Requantize32 &qp) { return  args._ci->has_sve2() && quant_hybrid_asymmetric(qp); },
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sve2() && quant_hybrid_asymmetric(qp); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
-},
-#endif // ARM_COMPUTE_ENABLE_SVE2
-{
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_HYBRID,
     "sve_hybrid_u8u32_dot_6x4VL",
-    [](const GemmArgs &args, const Requantize32 &) { return  args._ci->has_sve(); },
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve(); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
-},
-{
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "sve_interleaved_u8u32_dot_8x3VL",
-    [](const GemmArgs &args, const Requantize32 &) { return  args._ci->has_sve() && (args._Ksize>4); },
-    [](const GemmArgs &args, const Requantize32 &) { return  args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve() && (args._Ksize>4); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>(args, qp); }
-},
-#endif
-#ifdef ARM_COMPUTE_ENABLE_I8MM
-{
+),
+#endif // ARM_COMPUTE_ENABLE_SVE
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_u8qa_mmla_4x16",
+    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_i8mm() && quant_hybrid_asymmetric(qp); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_interleaved_u8u32_mmla_8x12",
     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_i8mm() && (args._Ksize>8); },
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() != CPUModel::KLEIN; },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>(args, qp); }
-},
-#endif
+),
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_hybrid_u8u32_mmla_6x16",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_i8mm(); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
+),
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
     "a64_smallK_hybrid_u8u32_dot_8x4",
@@ -125,35 +151,35 @@
     GemmMethod::GEMM_HYBRID,
     "a64_hybrid_u8qa_dot_4x16",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
-    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, int8_t, int8_t, Requantize32>::estimate_cycles(args, cls_a64_hybrid_u8qa_dot_4x16::get_performance_parameters(args._ci)); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
 ),
 GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_HYBRID,
     "a64_hybrid_u8u32_dot_6x16",
     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
-    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, int8_t, int8_t, Requantize32, true>::estimate_cycles(args, cls_a64_hybrid_u8u32_dot_6x16::get_performance_parameters(args._ci)); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
 ),
 GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_gemm_u8_8x12",
     [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
-    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, int8_t, int8_t>::estimate_cycles(args, cls_a64_gemm_u8_8x12::get_performance_parameters(args._ci)); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>(args, qp); }
 ),
-{
+GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
     GemmMethod::GEMM_INTERLEAVED,
     "a64_gemm_u8_4x4",
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &) { return !args._ci->has_dotprod(); },
+    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>(args, qp); }
-},
+),
 {
     GemmMethod::QUANTIZE_WRAPPER,
     "quantized_wrapper",
     [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
-    [](const GemmArgs &args, const Requantize32 &) { return !args._ci->has_dotprod(); },
+    [](const GemmArgs &, const Requantize32 &) { return false; },
     [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
 },
 {