COMPMID-1054 Update RSH's GEMM to add batch+multi support

Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 40f2858..2b4f35f 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -82,24 +82,19 @@
         const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
         const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
 
-        // Configure kernel window
-        Window     window  = calculate_max_window(*_d->info());
-        const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+        const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
 
-        // Only iterate over batches
-        Window win(window);
-        win.set(0, Window::Dimension(0, 1, 1));
-        win.set(1, Window::Dimension(0, 1, 1));
-        Iterator in0(_a, window);
-        Iterator out(_d, window);
-        execute_window_loop(win, [&](const Coordinates &)
-        {
-            const auto in0_ptr = reinterpret_cast<const TypeInput *>(in0.ptr());
-            auto       out_ptr = reinterpret_cast<TypeOutput *>(out.ptr());
-            _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd);
-            NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
-        },
-        in0, out);
+        const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+        const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+
+        const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
+        const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+        auto       out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
+
+        _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+        NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
     }
 };
 
@@ -146,10 +141,13 @@
     const int      M           = d->info()->tensor_shape().y();
     const int      N           = d->info()->tensor_shape().x();
     const int      K           = a->info()->tensor_shape().x();
+    const int      batches     = a->info()->tensor_shape().total_size_upper(2);
+    const int      multis      = b->info()->tensor_shape().z();
     unsigned int   num_threads = NEScheduler::get().num_threads();
+
     // unique_ptr to a Gemm object
     std::unique_ptr<typename T::AssemblyGemm>
-    asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, false, false, alpha, beta, num_threads, false));
+    asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false));
     // arm_compute wrapper for the Gemm object (see above)
     std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>
                                                                   acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>();