Fix GEMMLowp/Batched MatMul mismatches on CPU

- Fixes the column offset matrix not being iterated through in the y dimension
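
A minimal sketch of the fix (hypothetical helper, not part of this patch):
for batched matmul, vector_sum_col has shape [width, batch], so the per-batch
offset must advance along the tensor's y stride. The previous code stepped
between batches using the z stride, but the batches of this 2D tensor live in
its y dimension, so the per-batch column sums were not addressed correctly.

    #include <cstddef>
    #include <cstdint>

    // Sketch: select the column-sum row for a given batch.
    // vector_sum_col layout: x -> width, y -> batch.
    const int32_t *sum_col_for_batch(const uint8_t *base,      // start of vector_sum_col
                                     size_t         stride_y,  // bytes between batches
                                     int            batch_id)
    {
        // Fixed behaviour: step along y (the batch dimension). The old code
        // stepped along z, which is not the batch dimension of a 2D tensor.
        return reinterpret_cast<const int32_t *>(base + batch_id * stride_y);
    }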

Resolves: COMPMID-5795

Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Change-Id: I0190474be404b4f0e171855739cfd0a48cbed5bc
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9020
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
index 89aa364..c69af55 100644
--- a/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpOffsetContributionOutputStageKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2021, 2023 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -475,9 +475,18 @@
 template <typename T>
 void run_offset_contribution_output_stage(const Window &window,
                                           const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output,
-                                          int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col,
+                                          int32_t a_offset, int32_t b_offset, int32_t k_offset, bool is_vector_sum_col_batched,
                                           GEMMLowpOutputStageInfo output_stage, bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point)
 {
+    //  Semantics of the XYZW dimensions for each tensor
+    //
+    //  | Tensor            |    XYZW when is_gemm3d == false       |    XYZW when is_gemm3d == true                    |
+    // -------------------------------------------------------------------------------------------------------------------
+    //  | mm_result         |  x -> width,  y -> height, z -> batch |  x -> width, y -> height, z -> depth, w -> batch  |
+    //  | collapsed window  |  x -> width,  y -> height, z -> batch |  x -> width, y -> height, z -> depth * batch      |
+    //  | vector_sum_row    |  x -> height, y -> batch              |  x -> height * depth, y -> batch                  |
+    //  | vector_sum_col    |  x -> width,  y -> batch              |  x -> width, y -> batch                           |
+
     using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
     using Typer        = VectorTyper<T>;
 
@@ -517,8 +526,8 @@
 
         const size_t sum_row_stride_y = vector_sum_row->info()->strides_in_bytes().y();
 
-        // Offset in case vector_sum_col is batched
-        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+        // Offset in case vector_sum_col is batched in y dimension
+        const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;
 
         if(bias != nullptr)
         {
@@ -526,7 +535,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 const auto vector_sum_row_ptr = reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y)
                                                 + id.y() + (id.z() % depth_input) * height_input;
                 run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, vector_sum_row_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()),
@@ -544,7 +553,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 const auto vector_sum_row_ptr = reinterpret_cast<const int32_t *>(vector_sum_row_it.ptr() + batch_id * sum_row_stride_y)
                                                 + id.y() + (id.z() % depth_input) * height_input;
                 run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, vector_sum_row_ptr, nullptr, mm_result_it, out_it,
@@ -603,8 +612,8 @@
 
         Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);
 
-        // Offset in case vector_sum_col is batched
-        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+        // Offset in case vector_sum_col is batched in y dimension
+        const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;
 
         if(bias != nullptr)
         {
@@ -612,7 +621,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, nullptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it,
                                                                    out_it,
                                                                    result_offset_s32, result_shift_s32,
@@ -627,7 +636,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 run_offset_contribution_output_stage_window<Typer>(vector_sum_col_ptr, nullptr, nullptr, mm_result_it, out_it,
                                                                    result_offset_s32, result_shift_s32,
                                                                    min_vec, max_vec, a_offset, b_offset, k_offset,
@@ -670,7 +679,7 @@
 
 void run_offset_contribution_output_stage_symm(const Window &window,
                                                const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output,
-                                               int32_t a_offset, int32_t b_offset, int32_t k_offset, bool slide_vector_sum_col,
+                                               int32_t a_offset, int32_t b_offset, int32_t k_offset, bool is_vector_sum_col_batched,
                                                GEMMLowpOutputStageInfo output_stage, bool is_gemm3d, bool is_bounded_relu, bool is_fixed_point)
 {
     ARM_COMPUTE_UNUSED(vector_sum_row, b_offset, k_offset);
@@ -705,8 +714,8 @@
 
         Iterator vector_sum_col_it = get_vector_sum_col_it(collapsed_window, vector_sum_col);
 
-        // Offset in case vector_sum_col is batched
-        const int vector_sum_col_batch_offset = slide_vector_sum_col ? vector_sum_col->info()->strides_in_bytes().z() : 0;
+        // Offset in case vector_sum_col is batched in y dimension
+        const int vector_sum_col_stride_batch = is_vector_sum_col_batched ? vector_sum_col->info()->strides_in_bytes().y() : 0;
 
         if(bias != nullptr)
         {
@@ -714,7 +723,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 run_offset_contribution_output_stage_window_symm(vector_sum_col_ptr, reinterpret_cast<const int32_t *>(bias_it.ptr()), mm_result_it, out_it,
                                                                  result_multipliers, result_shifts,
                                                                  result_offset_s32, min_s8, max_s8,
@@ -728,7 +737,7 @@
             execute_window_loop(collapsed_window, [&](const Coordinates & id)
             {
                 const int  batch_id           = id.z() / depth_input;
-                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_batch_offset);
+                const auto vector_sum_col_ptr = reinterpret_cast<const int32_t *>(vector_sum_col_it.ptr() + batch_id * vector_sum_col_stride_batch);
                 run_offset_contribution_output_stage_window_symm(vector_sum_col_ptr, nullptr, mm_result_it, out_it,
                                                                  result_multipliers, result_shifts,
                                                                  result_offset_s32, min_s8, max_s8,
@@ -792,6 +801,7 @@
     {
         ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32);
         ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != mm_result->dimension(0));
+        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->num_dimensions() > 2);
     }
 
     // If b_offset == 0, vector_sum_row can be a nullptr
@@ -827,6 +837,9 @@
                                                 "vector_sum_col tensor must have the same number of batches of vector_sum_row_shape or the number of batches must be set to 1");
             }
         }
+
+        // Check tensor rank of vector_sum_row
+        ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_row->num_dimensions() > 2);
     }
 
     if(output->total_size() != 0)
@@ -860,7 +873,7 @@
-        // Check if vector_sum_col_shape should be slidden or not
-        // Don't slide vector_sum_col_shape along the y dimension if vector_sum_col_shape has just 1 dimension and vector_sum_row_shape more than 1
-        // This scenario can happen when the the matrix multiplication is used to perform a convolution operation
-        _slide_vector_sum_col = vector_sum_col->tensor_shape().num_dimensions() > 1;
+        // Check if vector_sum_col is batched
+        // Don't apply a batch offset to vector_sum_col if it has just 1 dimension while vector_sum_row has more than 1
+        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
+        _is_vector_sum_col_batched = vector_sum_col->tensor_shape().num_dimensions() > 1;
+        _is_vector_sum_col_batched = vector_sum_col->tensor_shape().num_dimensions() > 1;
     }
 
     // Output auto inizialitation if not yet initialized
@@ -919,19 +932,19 @@
 
     if(is_symm)
     {
-        run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage,
+        run_offset_contribution_output_stage_symm(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage,
                                                   reinterpret_as_3d, is_bounded_relu, is_fixed_point);
     }
     else
     {
         if(is_signed)
         {
-            run_offset_contribution_output_stage<int8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage,
+            run_offset_contribution_output_stage<int8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage,
                                                          reinterpret_as_3d, is_bounded_relu, is_fixed_point);
         }
         else
         {
-            run_offset_contribution_output_stage<uint8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _slide_vector_sum_col, _output_stage,
+            run_offset_contribution_output_stage<uint8_t>(window, mm_result, vector_sum_col, vector_sum_row, bias, dst, _a_offset, _b_offset, _k_offset, _is_vector_sum_col_batched, _output_stage,
                                                           reinterpret_as_3d, is_bounded_relu, is_fixed_point);
         }
     }
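
For reference, the quantized GEMM identity behind the kernel (a sketch with
assumed sign conventions, not kernel code): expanding (A - a_off) * (B - b_off)
gives A*B - b_off * rowsum(A) - a_off * colsum(B) + a_off * b_off * K, so each
batch of the output needs the column sums of its own slice of B, which is why
vector_sum_col must be stepped through per batch.

    #include <cstdint>

    // Sketch: offset contribution for one output element of one batch.
    // a_offset and b_offset are assumed pre-negated by the caller so the
    // contribution is additive; k_offset == a_offset * b_offset * K.
    int32_t offset_contribution(int32_t mm_result,  // raw A_q * B_q dot product
                                int32_t sum_col_x,  // colsum(B) at column x, this batch
                                int32_t sum_row_y,  // rowsum(A) at row y, this batch
                                int32_t a_offset, int32_t b_offset, int32_t k_offset)
    {
        return mm_result + a_offset * sum_col_x + b_offset * sum_row_y + k_offset;
    }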