COMPMID-3170: Remove padding in NEGEMMLowpMatrixMultiplyKernel

Change-Id: Ie95442c6c6a145c1a45937b03cbd433bf08e36ab
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4094
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
index c5d7f10..f3ba290 100644
--- a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
@@ -23,7 +23,6 @@
  */
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
 
-#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensor.h"
@@ -32,11 +31,7 @@
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/core/Window.h"
-
 #include <arm_neon.h>
-#include <cstddef>
-#include <cstdint>
-#include <tuple>
 
 using namespace arm_compute;
 
@@ -44,7 +39,7 @@
 {
 namespace
 {
-void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
 {
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -253,15 +248,29 @@
         }
 
         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
-        vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
-        vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
-        vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+        if(id.x() < (width_out - 16))
+        {
+            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
+            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
+            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
+            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+        }
+        else
+        {
+            auto left_over = width_out - id.x();
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(vec_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
 {
     execute_window_loop(window, [&](const Coordinates & id)
     {
@@ -469,17 +478,34 @@
         }
 
         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(vec_out + 0, c0.val[0]);
-        vst1q_s32(vec_out + 4, c0.val[1]);
-        vst1q_s32(vec_out + 8, c0.val[2]);
-        vst1q_s32(vec_out + 12, c0.val[3]);
+        if(id.x() < (width_out - 16))
+        {
+            vst1q_s32(vec_out + 0, c0.val[0]);
+            vst1q_s32(vec_out + 4, c0.val[1]);
+            vst1q_s32(vec_out + 8, c0.val[2]);
+            vst1q_s32(vec_out + 12, c0.val[3]);
+        }
+        else
+        {
+            auto left_over = width_out - id.x();
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(vec_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
 {
-    execute_window_loop(window, [&](const Coordinates &)
+    const auto   width_out  = static_cast<int>(out_info.dimension(0));
+    const auto   height_out = static_cast<int>(out_info.dimension(1));
+    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
+    execute_window_loop(window, [&](const Coordinates & id)
     {
         const uint8_t *mtx_a0 = ina.ptr();
         const uint8_t *mtx_b0 = inb.ptr();
@@ -574,32 +600,93 @@
         }
 
         auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
-        vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
-        vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
-        vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
-        vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
-        vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
-        vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
-        vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
-        vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
-        vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
-        vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
-        vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
-        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
-        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
-        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
-        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+
+        if(id.y() < height_out && id.x() < (width_out - 16))
+        {
+            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
+            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
+            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
+            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
+            if(id.y() + 1 < height_out)
+            {
+                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
+                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
+                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
+                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
+                if(id.y() + 2 < height_out)
+                {
+                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
+                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
+                    if(id.y() + 3 < height_out)
+                    {
+                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
+                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+                    }
+                }
+            }
+        }
+        else
+        {
+            const auto left_over_value = width_out - id.x();
+            auto       left_over       = left_over_value;
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(mtx_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+            if(id.y() + 1 < height_out)
+            {
+                left_over = left_over_value;
+                for(auto k = 0; k < 4 && left_over; ++k)
+                {
+                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                    {
+                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+                    }
+                }
+                if(id.y() + 2 < height_out)
+                {
+                    left_over = left_over_value;
+                    for(auto k = 0; k < 4 && left_over; ++k)
+                    {
+                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                        {
+                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+                        }
+                    }
+                    if(id.y() + 3 < height_out)
+                    {
+                        left_over = left_over_value;
+                        for(auto k = 0; k < 4 && left_over; ++k)
+                        {
+                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                            {
+                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+                            }
+                        }
+                    }
+                }
+            }
+        }
     },
     ina, inb, out);
 }
 
-void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
 {
+    const auto   width_out  = static_cast<int>(out_info.dimension(0));
+    const auto   height_out = static_cast<int>(out_info.dimension(1));
+    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
     // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with NEGEMMInterleave4x4 and NEGEMMTranspose1xW
     // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
     // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
-    execute_window_loop(window, [&](const Coordinates &)
+    execute_window_loop(window, [&](const Coordinates & id)
     {
         auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
         auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
@@ -692,32 +779,86 @@
             c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
             c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
         }
-
         auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
-        vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
-        vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
-        vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
-        vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
-        vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
-        vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
-        vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
-        vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
-        vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
-        vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
-        vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
-        vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
-        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
-        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
-        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
-        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+        if(id.y() < height_out && id.x() < (width_out - 16))
+        {
+            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
+            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
+            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
+            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
+            if(id.y() + 1 < height_out)
+            {
+                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
+                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
+                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
+                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
+                if(id.y() + 2 < height_out)
+                {
+                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
+                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
+                    if(id.y() + 3 < height_out)
+                    {
+                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
+                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+                    }
+                }
+            }
+        }
+        else if(id.y() < height_out)
+        {
+            const auto left_over_value = width_out - id.x();
+            auto       left_over       = left_over_value;
+            for(auto k = 0; k < 4 && left_over; ++k)
+            {
+                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                {
+                    *(mtx_out + k * 4 + j) = c0.val[k][j];
+                }
+            }
+            if(id.y() + 1 < height_out)
+            {
+                left_over = left_over_value;
+                for(auto k = 0; k < 4 && left_over; ++k)
+                {
+                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                    {
+                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+                    }
+                }
+                if(id.y() + 2 < height_out)
+                {
+                    left_over = left_over_value;
+                    for(auto k = 0; k < 4 && left_over; ++k)
+                    {
+                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                        {
+                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+                        }
+                    }
+                    if(id.y() + 3 < height_out)
+                    {
+                        left_over = left_over_value;
+                        for(auto k = 0; k < 4 && left_over; ++k)
+                        {
+                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+                            {
+                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
     },
     ina, inb, out);
 }
 } // namespace
 
-class Coordinates;
-} // namespace arm_compute
-
 namespace
 {
 Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
@@ -748,50 +889,6 @@
 
     return Status{};
 }
-
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
-{
-    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
-    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
-
-    Window win;
-    bool   window_changed = false;
-
-    // Check if the output tensor is a vector. If so,the kernel runs the vector-matrix multiplication
-    if((output->dimension(1) == 1))
-    {
-        // Configure kernel window
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
-
-        // We cannot read out-of-bound elements from matrix A as we use the left-over for loop
-        AccessWindowStatic     in0_access(input0, 0, 0, input0->tensor_shape().x(), 1);
-        AccessWindowHorizontal in1_access(input1, 0, num_elems_processed_per_iteration_x);
-        AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
-
-        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
-        Coordinates coord;
-        coord.set_num_dimensions(output->num_dimensions());
-        output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
-    }
-    else
-    {
-        win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
-        unsigned int num_k_iterations = ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x) / 16;
-        // For each iteration of "k" we increment the input pointer by 4, and we load 8 elements a the time:
-        AccessWindowStatic    in0_access(input0, 0, 0, (num_k_iterations - 1) * 4 + 8, input0->dimension(1));
-        AccessWindowStatic    in1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
-        AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
-
-        window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
-        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
-    }
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
 } // namespace
 
 NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
@@ -812,16 +909,33 @@
     _output         = output;
     _slide_matrix_b = in1_shape[2] != 1;
 
-    // Configure kernel window
-    auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
-    INEKernel::configure(win_config.second);
+    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
+    constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+    Window win;
+
+    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
+    if((output->info()->dimension(1) == 1))
+    {
+        // Configure kernel window
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+
+        Coordinates coord;
+        coord.set_num_dimensions(output->info()->num_dimensions());
+        output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+    }
+    else
+    {
+        win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+        output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+    }
+
+    INEKernel::configure(win);
 }
 
 Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
 {
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
 
     return Status{};
 }
@@ -837,6 +951,7 @@
     {
         const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
         const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
+        const auto width_out      = static_cast<int>(_output->info()->dimension(0));
         const auto in_b_stride    = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));
 
         // The implementation computes 16 elements per iteration
@@ -872,13 +987,13 @@
             case DataType::S8:
             case DataType::QASYMM8_SIGNED:
             {
-                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                 break;
             }
             case DataType::U8:
             case DataType::QASYMM8:
             {
-                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                 break;
             }
             default:
@@ -891,7 +1006,7 @@
     else
     {
         const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
-        const size_t out_stride  = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();
+        const int    width_b     = _input1->info()->dimension(0);
 
         // Set step_x and step_y for matrix A. Scale by a factor of 4 the Y range as the input interleaved matrix A has 4 times less the rows of the output matrix
         Window win_a(window);
@@ -914,19 +1029,18 @@
         Iterator inb(_input1, win_b);
         Iterator out(_output, window);
 
-        const int width_b = _input1->info()->dimension(0);
         switch(_input0->info()->data_type())
         {
             case DataType::S8:
             case DataType::QASYMM8_SIGNED:
             {
-                matrix_multiply_s8(ina, inb, out, width_b, out_stride, window);
+                matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
                 break;
             }
             case DataType::U8:
             case DataType::QASYMM8:
             {
-                matrix_multiply_u8(ina, inb, out, width_b, out_stride, window);
+                matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
                 break;
             }
             default:
@@ -937,3 +1051,6 @@
         }
     }
 }
+} // namespace arm_compute
+
+