Add in-place summation to CPU GEMM kernels

Instead of dispatching the sum post-op for GEMM kernels to a
separate kernel plus an addition, which requires an extra
destination-sized allocation and three extra loads/stores per
element, perform the summation directly in the GEMM kernel.

Resolves: ONCPUML-1442

Signed-off-by: Radu Salavat <radu.salavat@arm.com>
Co-authored-by: Milos Puzovic <milos.puzovic@arm.com>
Change-Id: I7a1f2da3300875fa1ac88b705a34390969518077
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11298
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index e85dd59..290fe87 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -293,14 +293,14 @@
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_fp32_mla_8x4",
-    [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
 },
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_fp32_mla_6x4",
-    [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
+    [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input && !args._accumulate; },
     nullptr,
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
 },
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 89c2d5a..0cc4d4f 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -530,7 +530,7 @@
                                  (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
                                  (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
                                  last_pass ? _args._act : Activation(),
-                                 !first_pass,
+                                 !first_pass || _args._accumulate,
                                  // Quantization parameters
                                  _os, _col_bias+(multi * _args._Nsize), n0);
                 } else if (_convolver) {
@@ -563,7 +563,7 @@
                                  (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
                                  (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
                                  last_pass ? _args._act : Activation(),
-                                 !first_pass,
+                                 !first_pass || _args._accumulate,
                                  // Quantization parameters
                                  _os, _col_bias+(multi * _args._Nsize), n0);
                 } else {
@@ -579,7 +579,7 @@
                                  (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg,
                                  (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr,
                                  last_pass ? _args._act : Activation(),
-                                 !first_pass,
+                                 !first_pass || _args._accumulate,
                                  // Quantization parameters
                                  _os, _col_bias+(multi * _args._Nsize), n0);
                 }
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index fd20e53..0dc0d55 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -128,14 +128,14 @@
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_s8s32_dot_8x4",
-    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_8x4, int8_t, int32_t>(args); }
 },
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_s8s32_dot_6x4",
-    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_s8s32_dot_6x4, int8_t, int32_t>(args); }
 },
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 4f732f7..d8b4645 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -350,6 +350,7 @@
     const bool _thread_columns;
 
     const Activation _act;
+    const bool _accumulate;
 
     const int _maxthreads;
     int _nthreads;
@@ -680,7 +681,7 @@
                       _Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
                       _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
                       _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
-                      _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+                      _act(args._act), _accumulate(args._accumulate), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
                       _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
                       _os(os) { }
 
@@ -690,7 +691,7 @@
                       _Ksections(args._Ksections), _Ktotal(get_ktotal(args)),
                       _rounded_Ksize(roundup(_Ksize, strategy::k_unroll())),
                       _nbatches(args._nbatches), _nmulti(args._nmulti), _thread_columns(is_thread_columns(args)),
-                      _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
+                      _act(args._act), _accumulate(args._accumulate),  _maxthreads(args._maxthreads), _nthreads(args._maxthreads),
                       _k_block(get_k_block_size(args)), _x_block(get_x_block_size(args)), _Mround(roundup(args._Msize, strategy::out_height())),
                       _os() { }
 
@@ -823,7 +824,7 @@
                             // Only do bias on the first pass
                             ((first_pass && this->_bias) ? this->_bias + (multi * this->_bias_multi_stride) : nullptr),
                             // Only do activation on the last pass, and accumulation on any non-first pass.
-                            (last_pass ? _act : Activation()), !first_pass,
+                            (last_pass ? _act : Activation()), (!first_pass || _accumulate),
                             // Pass in quantization parameters for requantizing kernels (others will ignore)
                             _os, col_bias + (multi * _Nsize),
                             // Accumulation buffer
@@ -971,7 +972,7 @@
                             // Only do bias on the first pass
                             ((first_pass && this->_bias) ? this->_bias + (current.multi() * this->_bias_multi_stride) : nullptr),
                             // Only do activation on the last pass, and accumulation on any non-first pass.
-                            (last_pass ? _act : Activation()), !first_pass,
+                            (last_pass ? _act : Activation()), (!first_pass || _accumulate),
                             // Pass in quantization parameters for requantizing kernels (others will ignore)
                             _os, col_bias + (current.multi() * _Nsize),
                             // Accumulation buffer
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index af5cfbb..dfacb68 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020, 2022-2023 Arm Limited.
+ * Copyright (c) 2017-2020, 2022-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -94,14 +94,14 @@
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_u8u32_dot_8x4",
-    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input && !args._accumulate; },
     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
 },
 {
     GemmMethod::GEMM_HYBRID,
     "a64_smallK_hybrid_u8u32_dot_6x4",
-    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
+    [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input && !args._accumulate; },
     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
 },
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 92c884c..dbada36 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -180,7 +180,7 @@
                                  this->_Cptr + (multi * this->_C_multi_stride) + n,
                                  (nmax - n), (kmax-k0),
                                  this->_bias ? this->_bias + (multi * this->_bias_multi_stride) + n : nullptr,
-                                 _args._act, (k0 != 0),
+                                 _args._act, (k0 != 0) || _args._accumulate,
                                  _os, col_bias, n + (_args._Nsize * multi));
                 }
             }
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 9a913c5..5d7cf79 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -21,6 +21,10 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  * SOFTWARE.
  */
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
+
 #pragma once
 
 #include "arm_gemm_local.hpp"
@@ -151,6 +155,7 @@
     int               _maxthreads;
     bool              _fixed_format;
     bool              _fast_mode;
+    bool              _accumulate;
     const GemmConfig *_cfg;
 
     GemmArgs(const CPUInfo    *ci,
@@ -165,6 +170,7 @@
              const int         maxthreads,
              bool              fixed_format = false,
              bool              fast_mode    = false,
+             bool              accumulate   = false,
              const GemmConfig *cfg          = nullptr)
         : _ci(ci),
           _Msize(M),
@@ -178,6 +184,7 @@
           _maxthreads(maxthreads),
           _fixed_format(fixed_format),
           _fast_mode(fast_mode),
+          _accumulate(accumulate),
           _cfg(cfg)
     {
     }
@@ -278,3 +285,5 @@
 bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
 
 } // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_ARM_GEMM_HPP
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index e035de0..905e86c 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -53,6 +53,7 @@
     asm_info.fast_mode               = info.fast_math();
     asm_info.fixed_format            = info.fixed_format();
     asm_info.weight_format           = info.weight_format();
+    asm_info.accumulate              = info.accumulate();
     asm_info.transpose_b =
         info.pretranspose_B(); // The "pretranspose_B" flag here is not the same as the pretranspose_B_array method. The flag here signals to pretranspose_B_array method if we want to perform additional transpose on B before the pretranspose_B_array method
 
@@ -219,6 +220,16 @@
                          const GEMMInfo    &gemm_info)
 {
     ARM_COMPUTE_UNUSED(alpha);
+    // When using accumulation (in-place summation), for now alpha must be 1, and beta must be 0 whenever a bias matrix c is provided.
+    // Do the appropriate checks before proceeding.
+    if (gemm_info.accumulate())
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(alpha != 1, "Accumulation is not supported when alpha is different from 1");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+            (beta != 0 && c != nullptr),
+            "Accumulation is not supported when beta is different from 0 with a non-null bias matrix c");
+    }
+
     const bool is_c_bias    = beta == 1 && c != nullptr;
     const bool run_addition = c != nullptr && beta != 0 && beta != 1;
     // Check if we should use the pretransposed_b or original b
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index b25505a..94e86c6 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -65,6 +65,7 @@
     asm_info.activation_info         = info.activation_info();
     asm_info.output_stage            = info.gemmlowp_output_stage();
     asm_info.fast_mode               = info.fast_math();
+    asm_info.accumulate              = info.accumulate();
 
     return asm_info;
 }
@@ -343,6 +344,13 @@
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
 
+    // When using accumulation (in-place summation), for now the only supported output DataType is S32, i.e. no GEMMLowp output stage may be requested.
+    if (gemm_info.accumulate())
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE,
+                                        "Accumulation is not supported for output QASYMM8/QASYMM8_SIGNED");
+    }
+
     GEMMInfo           info          = gemm_info;
     const ITensorInfo *matrix_a_info = a;
     const ITensorInfo *matrix_b_info = b;
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index efe2a7a..01a74a5 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -775,7 +775,7 @@
     arm_gemm::GemmConfig cfg;
     cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
     arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
-                            info.fixed_format, info.fast_mode, &cfg);
+                            info.fixed_format, info.fast_mode, info.accumulate, &cfg);
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -800,7 +800,7 @@
     arm_gemm::GemmConfig cfg;
     cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
     arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
-                            info.fixed_format, info.fast_mode, &cfg);
+                            info.fixed_format, info.fast_mode, info.accumulate, &cfg);
 
     // Create arm_gemm fallback
     auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -855,8 +855,7 @@
     cfg.weight_format                           = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
     arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
     arm_gemm::GemmArgs     args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
-                                info.fixed_format, info.fast_mode, &cfg);
-
+                                info.fixed_format, info.fast_mode, info.accumulate, &cfg);
     // TODO: Incorporate info.transpose_b COMPMID-6595
     switch (a->data_type())
     {
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index 671a222..44c5c18 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -57,6 +57,7 @@
     bool                      fixed_format{false};
     arm_compute::WeightFormat weight_format{arm_compute::WeightFormat::UNSPECIFIED};
     bool                      reshape_b_only_on_first_run{true};
+    bool                      accumulate{false};
     /** Whether we want to perform an additional transpose of b before passing it to gemm or pretranspose_B_array
      * @note This transpose b operation is also considered a form of "reshape" or "transform", so should be counted for
      *       by the reshape_b_only_on_first_run flag