COMPMID-1054 Update RSH's GEMM to add batch and multi support

Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
index 29c71f2..e5cc79e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
@@ -28,9 +28,12 @@
 #include "arm_gemm.hpp"
 
 #include "mergeresults.hpp"
-#include "profiler.hpp"
 #include "transform.hpp"
 
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
 namespace arm_gemm
 {
 // Implementation of the GemmCommon abstract class.
@@ -48,6 +51,7 @@
 
     const unsigned int _Nsize;
     const unsigned int _Ksize;
+    const unsigned int _nmultis;
 
     const Tr _beta;
 
@@ -60,45 +64,61 @@
     GemvNativeTransposed(GemvNativeTransposed &) = delete;
     GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete;
 
-    GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const Tr beta)
-        : _Nsize(N), _Ksize(K), _beta(beta), _ci(ci)
+    GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta)
+        : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci)
     {
         /* For now don't do any blocking. TODO: figure out if we should. */
         m_block = K;
         n_block = N;
     }
 
-    // Window is number of out_width blocks.
+    // Window is number of out_width blocks times number of multis.
     unsigned int get_window_size() const override
     {
-        return iceildiv(_Nsize, strategy::out_width);
+        return iceildiv(_Nsize, strategy::out_width) * _nmultis;
     }
 
     // Actually execute the GEMV.
     void execute(unsigned int start, unsigned int end, int) override
     {
+#ifdef CYCLE_PROFILING
         profiler prof;
+#endif
+
         strategy strat(_ci);
 
-        unsigned int N_start = start * strategy::out_width;
-        unsigned int N_end   = std::min(end * strategy::out_width, _Nsize);
+        const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
+        const unsigned int multi_0          = start / window_per_multi;
+        const unsigned int multi_end        = end / window_per_multi;
+
+        const unsigned int n_0   = (start - (multi_0 * window_per_multi)) * strategy::out_width;
+        const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
 
         static_assert(std::is_same<To, Toi>::value, "gemv_transposed: Operand types must be the same.");
         static_assert(std::is_same<Tr, Tri>::value, "gemv_transposed: Result types must be the same.");
 
-        for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+        for(unsigned int multi = multi_0; multi <= multi_end; multi++)
         {
-            unsigned int mmax = std::min(m0 + m_block, _Ksize);
+            const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
+            const unsigned int n_end   = (multi == multi_end) ? n_max : _Nsize;
 
-            for(unsigned int n0 = N_start; n0 < N_end; n0 += n_block)
+            if(n_end <= n_start)
+                continue;
+
+            for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
             {
-                unsigned int nmax = std::min(n0 + n_block, N_end);
-
-                prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void)
+                unsigned int mmax = std::min(m0 + m_block, _Ksize);
+                for(unsigned int n0 = n_start; n0 < n_end; n0 += n_block)
                 {
-                    strat.kernel(this->_Bptr + (m0 * this->_ldb) + n0, this->_Aptr + m0, this->_Cptr + n0,
+                    unsigned int nmax = std::min(n0 + n_block, n_end);
+#ifdef CYCLE_PROFILING
+                    auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n0));
+#endif
+                    strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0,
+                                 this->_Aptr + (multi * this->_A_multi_stride) + m0,
+                                 this->_Cptr + (multi * this->_C_multi_stride) + n0,
                                  _beta, this->_ldb, (mmax - m0), (nmax - n0));
-                });
+                }
             }
         }
     }