Fix leftover cols in CpuGemmLowpMatrixBReductionKernel

CpuGemmLowpMatrixBReductionKernel::run_internal randomly segfaults
because it reads out of bounds with vloadq. This doesn't trigger with
the unit tests because the read isn't out of bounds for the process, but
it can be seen clearly by running the following in debug mode

./examples/neon_gemm_qasymm8 1 1 1

The vloadq at src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp:353
accesses a quadword even though the input is a single byte.

relates to: ONCPUML-1444 MLINFSW-439 COMPMID-6844

Change-Id: I2ae5260c9f38d6d8149a6bcd5dc146b911209784
Signed-off-by: Jonathan Deakin <jonathan.deakin@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10966
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
index 9bd1eae..9a099bd 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixReductionKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021,2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -295,6 +295,7 @@
             }
 
             // Note: Since the input is unsigned char, we can safely use unsigned int for the accumulation
+            // 4 x u/int32x4_t = 16 column accumulators
             typename wrapper::traits::neon_bitvector<TAcc, wrapper::traits::BitWidth::W128>::type sum_col[4] = {
                 wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
                 wrapper::vdup_n(static_cast<TAcc>(0), wrapper::traits::vector_128_tag{}),
@@ -308,61 +309,91 @@
             asm volatile("PLD [%0, #128*4]" ::"r"(matrix_b + in_b_stride));
 #endif /* __arm__ */
 
-            int i = 0;
-            // This for loop performs 4 accumulations
-            for (; i <= (_k - 4); i += 4)
+            // If we have less than 16 columns left, we can't use the main unrolled loop
+            if ((width_matrix_b - id.x()) >= 16)
             {
-                const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
-                const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
-                const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
-                const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
+                // Row index
+                int i = 0;
+                // 4 x u/int32x4_t = 16 columns unrolled across 4 rows
+                for (; i <= (_k - 4); i += 4)
+                {
+                    // Load 4 rows of 16 columns of 8bit elements
+                    // (|                   |        )
+                    // (|                   |        )
+                    // (|                   |        )
+                    // (|                   |        )
+                    const auto b0_u8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+                    const auto b1_u8 = wrapper::vloadq(matrix_b + 1 * in_b_stride);
+                    const auto b2_u8 = wrapper::vloadq(matrix_b + 2 * in_b_stride);
+                    const auto b3_u8 = wrapper::vloadq(matrix_b + 3 * in_b_stride);
 
 #if __arm__
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
-                asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 1 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 2 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 3 * in_b_stride));
+                    asm volatile("PLD [%0, #128*1]" ::"r"(matrix_b + 4 * in_b_stride));
 #endif /* __arm__ */
 
-                // Partial accumulation in 16bit
-                typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] = {
-                    wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
-                    wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
+                    // Partial accumulation to 16bit (4 rows => 2 rows)
+                    // (|         |         |        )
+                    // (|         |         |        )
+                    typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type tmp_sum[2] =
+                        {wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{}),
+                         wrapper::vdup_n(static_cast<TIAcc>(0), wrapper::traits::vector_128_tag{})};
 
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
-                tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
-                tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b1_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b0_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b2_u8));
+                    tmp_sum[0] = wrapper::vaddw(tmp_sum[0], wrapper::vgetlow(b3_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b0_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b1_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b2_u8));
+                    tmp_sum[1] = wrapper::vaddw(tmp_sum[1], wrapper::vgethigh(b3_u8));
 
-                // Accumulate to 32bit
-                sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
-                sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
-                sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
-                sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
+                    // Accumulate to 32bit (2 rows => 1 row)
+                    // (|    |    |    |    |        )
+                    sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(tmp_sum[0]));
+                    sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(tmp_sum[0]));
+                    sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(tmp_sum[1]));
+                    sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(tmp_sum[1]));
 
-                matrix_b += 4 * in_b_stride;
+                    matrix_b += 4 * in_b_stride;
+                }
+
+                // This for loop accumulates the rows left over from the 4x unrolling above
+                for (; i < _k; ++i)
+                {
+                    const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
+
+                    // Convert 8bit => 16bit
+                    const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type
+                        b0_b16[2]{wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
+
+                    // Accumulate to 32bit
+                    sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
+                    sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
+                    sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
+                    sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
+
+                    matrix_b += in_b_stride;
+                }
             }
-
-            // This for loop perfoms the leftover accumulations
-            for (; i < _k; ++i)
+            else
             {
-                const auto b0_b8 = wrapper::vloadq(matrix_b + 0 * in_b_stride);
-
-                // Convert S8 to S16
-                const typename wrapper::traits::neon_bitvector<TIAcc, wrapper::traits::BitWidth::W128>::type b0_b16[2]{
-                    wrapper::vmovl(wrapper::vgetlow(b0_b8)), wrapper::vmovl(wrapper::vgethigh(b0_b8))};
-
-                // Accumulate to 32bit
-                sum_col[0] = wrapper::vaddw(sum_col[0], wrapper::vgetlow(b0_b16[0]));
-                sum_col[1] = wrapper::vaddw(sum_col[1], wrapper::vgethigh(b0_b16[0]));
-                sum_col[2] = wrapper::vaddw(sum_col[2], wrapper::vgetlow(b0_b16[1]));
-                sum_col[3] = wrapper::vaddw(sum_col[3], wrapper::vgethigh(b0_b16[1]));
-
-                matrix_b += in_b_stride;
+                // Accumulate left over columns to sum_cols
+                for (int i = 0; i < _k; ++i) // row loop
+                {
+                    auto left_over_cols = width_matrix_b - id.x();
+                    auto l              = left_over_cols;
+                    for (auto k = 0; k < 4 && l; ++k)
+                    {
+                        for (auto j = 0; j < 4 && l; ++j, --l)
+                        {
+                            sum_col[k][j] += matrix_b[left_over_cols - l];
+                        }
+                    }
+                    matrix_b += in_b_stride;
+                }
             }
 
             // Multiply by scalar if necessary
@@ -375,7 +406,7 @@
             }
 
             auto vector_sum_col = reinterpret_cast<int32_t *>(out.ptr());
-            if (id.x() + 16 < width_matrix_b)
+            if ((width_matrix_b - id.x()) >= 16)
             {
                 wrapper::vstore(vector_sum_col + 0, wrapper::vreinterpret(sum_col[0]));
                 wrapper::vstore(vector_sum_col + 4, wrapper::vreinterpret(sum_col[1]));