New SME2 heuristics.

Prefer the 1VLx4VL SME2 kernels for wide problems by adding an
N >= 8*VL term to their shape heuristics in the BF16, FP16, FP32,
int8 and quantized (s8q/u8q) GEMM lists, and register the FP16
1VLx4VL kernel ahead of the 4VLx1VL one so it takes priority when
both are eligible.

In gemm_interleaved.hpp: scale the SME K-blocking sizes with the
L1 cache size (targeting L1/128 bytes per block, minimum 512)
instead of fixed constants; permit K blocking for requantizing
kernels that skip the merge step; add SME-specific N blocking
(whole width below 4x kernel width, a single kernel width
otherwise); and avoid offsetting a null C pointer when calling the
quantized inline kernel.
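For illustration (not part of the patch), a standalone sketch of the
updated 1VLx4VL shape heuristic, where VL stands in for
sme::get_vector_length<float>():

    // Sketch only: mirrors the selection heuristic changed below.
    bool prefer_1VLx4VL(unsigned M, unsigned N, unsigned VL) {
        // New in this patch: problems at least eight vector lengths
        // wide in N also select this shape, in addition to the
        // existing M windows.
        return N >= 8 * VL || M <= VL || (2 * VL < M && M <= 3 * VL);
    }

With a 512-bit streaming vector length VL is 16, so these kernels are
now chosen for N >= 128 as well as for M <= 16 or 32 < M <= 48.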

Change-Id: I69aa973e61df950060807a31230a1edd91add498
Signed-off-by: David Mansell <David.Mansell@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11514
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
index 5c08e61..0ddca04 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp
@@ -86,7 +86,7 @@
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 3b444ae..c7adf8e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -69,6 +69,14 @@
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
+    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
+    [](const GemmArgs &args) { return args._ci->has_sme2(); },
+    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
+},
+{
+    GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_fp16fp32fp16_mopa_4VLx1VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
@@ -77,14 +85,6 @@
 },
 {
     GemmMethod::GEMM_INTERLEAVED,
-    "sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL",
-    [](const GemmArgs &args) { return args._ci->has_sme2(); },
-    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
-    [](const GemmArgs &args) { return new GemmInterleaved<cls_sme2_interleaved_nomerge_fp16fp32fp16_mopa_1VLx4VL, __fp16, __fp16, Nothing, false, false, false, true>(args); }
-},
-{
-    GemmMethod::GEMM_INTERLEAVED,
     "sme2_interleaved_nomerge_fp16fp32fp16_mopa_2VLx2VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     nullptr,
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index af0d38e..f223dea 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -141,7 +141,7 @@
     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
 },
 #endif // ARM_COMPUTE_ENABLE_BF16
@@ -150,7 +150,7 @@
     "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2() && !args._accumulate; },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
 },
 #ifdef ARM_COMPUTE_ENABLE_BF16
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 0dc0d55..fedda3a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -63,7 +63,7 @@
     "sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL",
     [](const GemmArgs &args) { return args._ci->has_sme2(); },
     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<int32_t>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_s8s32_mopa_1VLx4VL, int8_t, int32_t>(args); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index ae344f0..897ec9d 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -190,10 +190,19 @@
     auto p=prof.ScopedProfiler(PROFILE_KERNEL, (m_max - m_0) * (n_max - n_0) * kern_k);
 #endif
 
+    // Offset C pointer in a similar way to non-quantized case above.
+    Tri *offset_c_ptr;
+
+    if (c_ptr == nullptr) {
+        offset_c_ptr = nullptr;
+    } else {
+        offset_c_ptr = c_ptr + m_0 * ldc + n_0;
+    }
+
     strat.kernel(// A and B pointers are just the packed panels.
                  a_ptr, b_panel,
                  // Provide relevant part of output array and row stride.
-                 c_ptr + m_0 * ldc + n_0, ldc,
+                 offset_c_ptr, ldc,
                  // M, N, K sizes
                  m_max-m_0, n_max - n_0, kern_k,
                  // Bias, activation, accumulation.  Need to offset the bias as needed.
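Note (context, not part of the patch): forming c_ptr + m_0 * ldc + n_0
while c_ptr is null is undefined behaviour in C++, so the offset above
is now applied only to a non-null pointer, mirroring the non-quantized
path. A minimal sketch of the same guard, with the surrounding names
assumed:

    Tri *offset_c_ptr = (c_ptr == nullptr) ? nullptr
                                           : c_ptr + m_0 * ldc + n_0;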
@@ -663,15 +672,27 @@
             return roundup(args._cfg->inner_block_size, strategy::k_unroll());
         }
 
-        // K blocking not supported if we are requantizing.
-        if (std::is_same<OutputStage, Requantize32>::value) {
+        // K blocking not supported if we are requantizing with the merging
+        // kernels.
+        if (std::is_same<OutputStage, Requantize32>::value && MergeStep) {
             return get_ktotal(args);
         }
 
+        const unsigned int L1_size = args._ci->get_L1_cache_size();
+
         // Special blocking for SME
         if (is_sme<strategy>::value) {
-            // Don't bother to block below this size threshold, experimentally determined to be 320 for FP32
-            unsigned int scaling_threshold = 1280 / sizeof(Toi);
+            // Target 512 bytes for 64kB L1, or 1024 bytes for 128kB L1.
+            unsigned int target_bytes_per_block = L1_size / 128;
+
+            // Default cache size in gemm-linux is 32kB though - so make
+            // sure the minimum is 512 bytes.
+            if (target_bytes_per_block < 512) {
+                target_bytes_per_block = 512;
+            }
+
+            // Don't bother to block below this size threshold (1.25X target size)
+            unsigned int scaling_threshold = ((target_bytes_per_block * 5) / 4) / sizeof(Toi);
 
             if (get_ktotal(args) <= scaling_threshold) {
                 return get_ktotal(args);
@@ -679,7 +700,7 @@
 
             // Once we are blocking, this (lower) threshold determines when we should use more blocks
             // NOTE: Could be that some factor-based solution would work better here.
-            unsigned int max_block_size = 1024 / sizeof(Toi);
+            unsigned int max_block_size = target_bytes_per_block / sizeof(Toi);
 
             unsigned int num_k_blocks = iceildiv(get_ktotal(args), max_block_size);
 
@@ -688,7 +709,6 @@
             return k_block;
         }
 
-        const unsigned int L1_size = args._ci->get_L1_cache_size();
         unsigned int k_block;
 
         // k_block: Find out how much of the larger array can be loaded into half the cache.
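Worked example (illustrative only), assuming a 64 kB L1 and fp32, i.e.
sizeof(Toi) == 4:

    target_bytes_per_block = 65536 / 128     = 512 bytes
    scaling_threshold      = ((512 * 5) / 4) / 4 = 160 elements
    max_block_size         = 512 / 4         = 128 elements

So K <= 160 stays unblocked, while K = 300 gives
iceildiv(300, 128) = 3 blocks of roughly 100 elements each (before
rounding to k_unroll()). The previous fixed constants (1280-byte
threshold, 1024-byte blocks) correspond to this formula at a
128 kB L1.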
@@ -723,6 +743,17 @@
             return roundup(args._cfg->outer_block_size, strategy::out_width());
         }
 
+        // Special blocking for SME
+        if (is_sme<strategy>::value) {
+            // If total width is less than 4x kernel width, return the entire width.
+            if (args._Nsize < strategy::out_width()*4) {
+                return roundup(args._Nsize, strategy::out_width());
+            }
+
+            // Otherwise block to single kernel width.
+            return strategy::out_width();
+        }
+
         unsigned int x_block;
         const unsigned int L2_size = args._ci->get_L2_cache_size();
         const unsigned int k_block = get_k_block_size(args);
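Worked example of the new SME N blocking (illustrative only, with a
hypothetical strategy::out_width() of 64):

    N = 200: 200 < 4 * 64 = 256, so one block of roundup(200, 64) = 256
    N = 500: 500 >= 256, so blocks of a single kernel width, 64

That is, narrow outputs are handled in a single pass, while wide
outputs are walked one kernel width at a time.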
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
index d1c4e49..321c972 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp
@@ -82,7 +82,7 @@
     "sme2_interleaved_nomerge_s8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_s8q_mopa_1VLx4VL, int8_t, int8_t>(args, qp); }
 },
 {
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
index b85b1c4..93eecf9 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp
@@ -78,7 +78,7 @@
     "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
     [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
     [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>();
-                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
+                               return args._Nsize >= 8*VL || args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
     [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); }
 },
 {