COMPMID-3776: Indirect GEMM

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I51a1b0f098bc3a8c408c50c92221e4df3061e12c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4343
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index 04cac60..05c5116 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -25,68 +25,151 @@
 
 #include "arm_gemm.hpp"
 
-#include "kernels/a64_hybrid_s8s32_dot_16x4.hpp"
-#include "kernels/a64_smallK_hybrid_s8s32_dot_4x6.hpp"
-#include "kernels/a64_smallK_hybrid_s8s32_dot_4x8.hpp"
-#include "kernels/sve_hybrid_s8s32_dot_4VLx4.hpp"
-#include "kernels/sve_smallK_hybrid_s8s32_dot_1VLx8.hpp"
+#include "kernels/a64_gemm_s16_8x12.hpp"
+#include "kernels/a64_gemm_s8_4x4.hpp"
+#include "kernels/a64_gemm_s8_8x12.hpp"
+#include "kernels/a64_hybrid_s8qa_dot_4x16.hpp"
+#include "kernels/a64_hybrid_s8qs_dot_6x16.hpp"
+#include "kernels/a64_hybrid_s8s32_dot_6x16.hpp"
+#include "kernels/a64_interleaved_s8s32_mmla_8x12.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_6x4.hpp"
+#include "kernels/a64_smallK_hybrid_s8s32_dot_8x4.hpp"
 
+#include "kernels/sve_hybrid_s8s32_dot_6x4VL.hpp"
+#include "kernels/sve_hybrid_s8qa_dot_4x4VL.hpp"
+#include "kernels/sve_hybrid_s8qs_dot_6x4VL.hpp"
+#include "kernels/sve_interleaved_s8s32_dot_8x3VL.hpp"
+#include "kernels/sve_interleaved_s8s32_mmla_8x3VL.hpp"
+#include "kernels/sve_smallK_hybrid_s8s32_dot_8x1VL.hpp"
+
+#include "gemm_hybrid_indirect.hpp"
 #include "gemm_hybrid_quantized.hpp"
+#include "gemm_hybrid_quantized_inline.hpp"
+#include "gemm_interleaved.hpp"
 #include "quantize_wrapper.hpp"
+#include "utils.hpp"
 
 namespace arm_gemm {
 
 static const GemmImplementation<int8_t, int8_t, Requantize32> gemm_qint8_methods[] =
 {
 #ifdef __ARM_FEATURE_SVE
+#ifdef MMLA_INT8
 {
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_s8s32_dot_1VLx8",
-    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64; },
+    GemmMethod::GEMM_INTERLEAVED,
+    "sve_interleaved_s8s32_mmla_8x3VL",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_1VLx8, int8_t, int8_t>(args, qp); }
-},
-{
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "hybrid_s8s32_dot_4VLx4",
-    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize>=16; },
-    [](const GemmArgs &args, const Requantize32 &) { return ((args._Ksize <= 128) && (args._Nsize <= 128)) || ((args._nmulti > 1) && ((args._Msize / args._maxthreads) < 8)); },
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_4VLx4, int8_t, int8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_s8s32_mmla_8x3VL, int8_t, int8_t>(args, qp); }
 },
 #endif
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_s8s32_dot_4x8",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32); },
+    "sve_smallK_hybrid_s8s32_dot_8x1VL",
+    [](const GemmArgs &args, const Requantize32 &) { return args._Ksize<=64 && !args._indirect_input; },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x8, int8_t, int8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_s8s32_dot_8x1VL, int8_t, int8_t>(args, qp); }
+},
+#ifdef SVE2 // the two requantizing SVE hybrid entries below are compiled only when SVE2 is defined
+{
+    GemmMethod::GEMM_HYBRID,
+    "sve_hybrid_s8qs_dot_6x4VL",
+    [](const GemmArgs &, const Requantize32 &qp) { return quant_hybrid_symmetric(qp); }, // decided by the quantization parameters alone; GemmArgs is left unnamed because it is unused
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8qs_dot_6x4VL, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "sve_hybrid_s8qa_dot_4x4VL",
+    [](const GemmArgs &, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp); }, // decided by the quantization parameters alone; GemmArgs is left unnamed because it is unused
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8qa_dot_4x4VL, int8_t, int8_t, Requantize32>(args, qp); }
+},
+#endif // SVE2
+{
+    GemmMethod::GEMM_HYBRID,
+    "sve_hybrid_s8s32_dot_6x4VL",
+    nullptr,
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_s8s32_dot_6x4VL, int8_t, int8_t, Requantize32, true>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "sve_interleaved_s8s32_dot_8x3VL",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>4); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_s8s32_dot_8x3VL, int8_t, int8_t>(args, qp); }
+},
+#endif // SVE
+#ifdef MMLA_INT8 // entry below is compiled only when MMLA_INT8 is defined
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_interleaved_s8s32_mmla_8x12",
+    [](const GemmArgs &args, const Requantize32 &) { return (args._Ksize>8); }, // predicate passes only when the K size exceeds 8; quantization parameters are ignored
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_s8s32_mmla_8x12, int8_t, int8_t>(args, qp); }
+},
+#endif // MMLA_INT8
+{
+    GemmMethod::GEMM_HYBRID_QUANTIZED,
+    "a64_smallK_hybrid_s8s32_dot_8x4",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int8_t>(args, qp); }
 },
 {
     GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "smallK_hybrid_s8s32_dot_4x6",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64); },
+    "a64_smallK_hybrid_s8s32_dot_6x4",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<smallK_hybrid_s8s32_dot_4x6, int8_t, int8_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int8_t>(args, qp); }
 },
 {
-    GemmMethod::GEMM_HYBRID_QUANTIZED,
-    "hybrid_s8s32_dot_16x4",
-    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && args._Ksize>=16; },
-    [](const GemmArgs &args, const Requantize32 &) { return args._Nsize<=256 && args._Ksize>128; },
-    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<hybrid_s8s32_dot_16x4, int8_t, int8_t>(args, qp); }
-},
-/** QUANTIZE_WRAPPER_2D enables 2D parallelisation hint for IScheduler in NEGEMMAssemblyDispatch */
-{
-    GemmMethod::QUANTIZE_WRAPPER_2D,
-    "quantized_wrapper_2d",
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_s16_8x12",
     nullptr,
-    [](const GemmArgs &args, const Requantize32 &) { return (args._maxthreads >= 8) && (args._Msize >= 8) && (args._Nsize >= 8);},
-    [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); }
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() == CPUModel::A53; },
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s16_8x12, int8_t, int8_t>(args, qp); }
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_s8qs_dot_6x16",
+    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_symmetric(qp); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8qs_dot_6x16, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_s8qa_dot_4x16",
+    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8qa_dot_4x16, int8_t, int8_t, Requantize32>(args, qp); }
+},
+{
+    GemmMethod::GEMM_HYBRID,
+    "a64_hybrid_s8s32_dot_6x16",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_s8s32_dot_6x16, int8_t, int8_t, Requantize32, true>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_s8_8x12",
+    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s8_8x12, int8_t, int8_t>(args, qp); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
+    "a64_gemm_s8_4x4",
+    nullptr,
+    nullptr,
+    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_s8_4x4, int8_t, int8_t>(args, qp); }
 },
 {
     GemmMethod::QUANTIZE_WRAPPER,
     "quantized_wrapper",
-    nullptr,
+    [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
     nullptr,
     [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<int8_t, int8_t, int32_t>(args, qp); }
 },