COMPMID-3776: Indirect GEMM

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index 0125f9c..7342fda 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -25,13 +25,25 @@
 
 #include "arm_gemm.hpp"
 
-#include "kernels/a64_hybrid_u8u32_dot_16x4.hpp"
-#include "kernels/a64_smallK_hybrid_u8u32_dot_4x6.hpp"
-#include "kernels/a64_smallK_hybrid_u8u32_dot_4x8.hpp"
-#include "kernels/sve_hybrid_u8u32_dot_4VLx4.hpp"
-#include "kernels/sve_smallK_hybrid_u8u32_dot_1VLx8.hpp"
+#include "kernels/a64_gemm_u16_8x12.hpp"
+#include "kernels/a64_gemm_u8_4x4.hpp"
+#include "kernels/a64_gemm_u8_8x12.hpp"
+#include "kernels/a64_hybrid_u8qa_dot_4x16.hpp"
+#include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
+#include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
+#include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
+#include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
 
+#include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
+#include "kernels/sve_hybrid_u8qa_dot_4x4VL.hpp"
+#include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
+#include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
+#include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
+
+#include "gemm_hybrid_indirect.hpp"
 #include "gemm_hybrid_quantized.hpp"
+#include "gemm_hybrid_quantized_inline.hpp"
+#include "gemm_interleaved.hpp"
 #include "quantize_wrapper.hpp"
 
 namespace arm_gemm {
@@ -39,54 +51,108 @@
 static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
 {
 #ifdef __ARM_FEATURE_SVE
+#ifdef MMLA_INT8
 {
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_u8u32_dot_1VLx8",
-    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64; },
+    GemmMethod::GEMM_INTERLEAVED,
+    "sve_interleaved_u8u32_mmla_8x3VL",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_1VLx8, uint8_t, uint8_t>(args, qp); }
-},
-{
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "hybrid_u8u32_dot_4VLx4",
-    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16; },
-    [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_4VLx4, uint8_t, uint8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>(args, qp); }
 },
 #endif
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_u8u32_dot_4x8",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32); },
+    "sve_smallK_hybrid_u8u32_dot_8x1VL",
+    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._indirect_input; },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x8, uint8_t, uint8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint8_t>(args, qp); }
+},
+#ifdef SVE2 // Requantizing kernels include some SVE2 only instructions (SQRDMULH, SRSHL)
+{
+    GemmMethod::GEMM_HYBRID, 
+    "sve_hybrid_u8qa_dot_4x4VL",
+    [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
+},
+#endif
+{
+    GemmMethod::GEMM_HYBRID, 
+    "sve_hybrid_u8u32_dot_6x4VL",
+    nullptr,
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "sve_interleaved_u8u32_dot_8x3VL",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>4); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>(args, qp); }
+},
+#endif
+#ifdef MMLA_INT8
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_interleaved_u8u32_mmla_8x12",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>(args, qp); }
+},
+#endif
+{
+    GemmMethod::GEMM_HYBRID_QUANTIZED,
+    "a64_smallK_hybrid_u8u32_dot_8x4",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint8_t>(args, qp); }
 },
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_u8u32_dot_4x6",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64); },
+    "a64_smallK_hybrid_u8u32_dot_6x4",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint8_t>(args, qp); }
 },
 {
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "hybrid_u8u32_dot_16x4",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16; },
-    [](const GemmArgs &args, const Requantize32 &) { return ((args._Nsize<=256) && (args._Ksize>128)) || (args._maxthreads >= 8); },
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_u8u32_dot_16x4, uint8_t, uint8_t>(args, qp); }
-},
-/** QUANTIZE_WRAPPER_2D enables 2D parallelisation hint for IScheduler in NEGEMMAssemblyDispatch */
-{
-    GemmMethod::QUANTIZE_WRAPPER_2D,
-    "quantized_wrapper_2d",
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_u16_8x12",
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &) { return (args._maxthreads >= 8) && (args._Msize >= 8) && (args._Nsize >= 8);},
-    [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() == CPUModel::A53; },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u16_8x12, uint8_t, uint8_t>(args, qp); },
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_u8qa_dot_4x16",
+    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
+    [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_u8u32_dot_6x16",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+    [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_u8_8x12",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_u8_4x4",
+    nullptr,
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>(args, qp); }
 },
 {
     GemmMethod::QUANTIZE_WRAPPER,
     "quantized_wrapper",
-    nullptr,
+    [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
     nullptr,
     [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
 },