COMPMID-1056 - Optimize CLGEMMMatrixMultiplyKernel by refactoring the inner loop

Results reported at:
https://confluence.arm.com/display/MLENG/GEMM+FP32+performance%3A+ACL+18.05
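
The inner loops in gemm.cl now count the processed column blocks with an
integer index instead of comparing a pointer against a precomputed end
address, advance the matrix A/B pointers right after each vload, and keep
four fully unrolled accumulation steps per iteration, with a leftover loop
for the remaining columns. The non-reshaped FP32 kernels get the same
structure over COLS_A, using wider float4/float8 loads from matrix A and
stepping matrix B one row (src1_stride_y) at a time. On the runtime side,
CLFullyConnectedLayer now runs the float path through CLGEMM instead of
calling CLGEMMMatrixMultiplyKernel directly.

A minimal OpenCL C sketch of the new loop shape for the interleaved /
transposed path is given below. It is illustrative only, not the kernel
itself: dot_block_sketch and the single c00 accumulator are made up for
the example, while COLS_B, MULT_INTERLEAVE4X4_HEIGHT and
MULT_TRANSPOSE1XW_WIDTH are the build-time -D defines gemm.cl already uses.

    // Sketch of the refactored inner loop (interleaved A / transposed B, FP32).
    // The real kernel keeps a full 4x4 accumulator block and spells the four
    // unrolled steps out by hand; a plain inner loop is used here for brevity.
    #define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    inline float dot_block_sketch(__global const float *src_addr_a,
                                  __global const float *src_addr_b)
    {
        float c00 = 0.0f;

        int i = 0;
        for(; i <= (int)(COLS_MTX_B - 4); i += 4)
        {
            // Four load / advance / fma steps per iteration.
            for(int k = 0; k < 4; ++k)
            {
                float4 a0 = vload4(0, src_addr_a);
                float4 b0 = vload4(0, src_addr_b);

                // Pointers advance right after the loads, not in the for-statement.
                src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
                src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

                c00 = fma(a0.s0, b0.s0, c00);
            }
        }

        // Leftover column blocks of matrix B.
        for(; i < (int)(COLS_MTX_B); ++i)
        {
            float4 a0 = vload4(0, src_addr_a);
            float4 b0 = vload4(0, src_addr_b);

            src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
            src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

            c00 = fma(a0.s0, b0.s0, c00);
        }

        return c00;
    }

The full kernels in this patch apply the same load / advance / fma pattern
to every accumulator of the output block.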

Change-Id: I3246c4f19c4d21a7d6a44e4593bc5caffc016f81
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127838
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 584266b..67c0467 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -27,11 +27,11 @@
 #include "arm_compute/runtime/CL/ICLSimpleFunction.h"
 
 #include "arm_compute/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
 #include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
 #include "arm_compute/core/CL/kernels/CLTransposeKernel.h"
 #include "arm_compute/runtime/CL/CLMemoryGroup.h"
 #include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/CL/functions/CLGEMM.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
 #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
 
@@ -113,12 +113,12 @@
 private:
     void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
     void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
-    void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed = true);
+    void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
 
     CLMemoryGroup                                       _memory_group;
     CLIm2ColKernel                                      _im2col_kernel;
     CLFullyConnectedLayerReshapeWeights                 _reshape_weights_kernel;
-    CLGEMMMatrixMultiplyKernel                          _mm_kernel;
+    CLGEMM                                              _mm_gemm;
     CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
     CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
     CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel;
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 4b1672c..381130e 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -342,9 +342,6 @@
     __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
     __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
 
-    // Compute end row address for matrix B
-    __global float *src_end_addr_b = src_addr_b + COLS_B;
-
     src_addr_a += offset_row_a;
     src_addr_b += offset_row_b;
 
@@ -366,35 +363,17 @@
     float c32 = 0.0f;
     float c33 = 0.0f;
 
-    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += (16 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (16 * MULT_TRANSPOSE1XW_WIDTH))
+#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
+
+    int i = 0;
+    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
     {
         // Load values from matrix A (interleaved) and matrix B (transposed)
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
-        c00 = fma(a0.s0, b0.s0, c00);
-        c01 = fma(a0.s0, b0.s1, c01);
-        c02 = fma(a0.s0, b0.s2, c02);
-        c03 = fma(a0.s0, b0.s3, c03);
-
-        c10 = fma(a0.s1, b0.s0, c10);
-        c11 = fma(a0.s1, b0.s1, c11);
-        c12 = fma(a0.s1, b0.s2, c12);
-        c13 = fma(a0.s1, b0.s3, c13);
-
-        c20 = fma(a0.s2, b0.s0, c20);
-        c21 = fma(a0.s2, b0.s1, c21);
-        c22 = fma(a0.s2, b0.s2, c22);
-        c23 = fma(a0.s2, b0.s3, c23);
-
-        c30 = fma(a0.s3, b0.s0, c30);
-        c31 = fma(a0.s3, b0.s1, c31);
-        c32 = fma(a0.s3, b0.s2, c32);
-        c33 = fma(a0.s3, b0.s3, c33);
-
-        // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -417,8 +396,11 @@
         c33 = fma(a0.s3, b0.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 8 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -441,8 +423,38 @@
         c33 = fma(a0.s3, b0.s3, c33);
 
         // Load values from matrix A (interleaved) and matrix B (transposed)
-        a0 = vload4(0, src_addr_a + 12 * MULT_INTERLEAVE4X4_HEIGHT);
-        b0 = vload4(0, src_addr_b + 12 * MULT_TRANSPOSE1XW_WIDTH);
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma(a0.s0, b0.s0, c00);
+        c01 = fma(a0.s0, b0.s1, c01);
+        c02 = fma(a0.s0, b0.s2, c02);
+        c03 = fma(a0.s0, b0.s3, c03);
+
+        c10 = fma(a0.s1, b0.s0, c10);
+        c11 = fma(a0.s1, b0.s1, c11);
+        c12 = fma(a0.s1, b0.s2, c12);
+        c13 = fma(a0.s1, b0.s3, c13);
+
+        c20 = fma(a0.s2, b0.s0, c20);
+        c21 = fma(a0.s2, b0.s1, c21);
+        c22 = fma(a0.s2, b0.s2, c22);
+        c23 = fma(a0.s2, b0.s3, c23);
+
+        c30 = fma(a0.s3, b0.s0, c30);
+        c31 = fma(a0.s3, b0.s1, c31);
+        c32 = fma(a0.s3, b0.s2, c32);
+        c33 = fma(a0.s3, b0.s3, c33);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload4(0, src_addr_a);
+        b0 = vload4(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
 
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
@@ -465,12 +477,15 @@
         c33 = fma(a0.s3, b0.s3, c33);
     }
 
-    for(; src_addr_b < src_end_addr_b; src_addr_a += (4 * MULT_INTERLEAVE4X4_HEIGHT), src_addr_b += (4 * MULT_TRANSPOSE1XW_WIDTH))
+    for(; i < (int)(COLS_MTX_B); ++i)
     {
         // Load values from matrix A (interleaved) and matrix B (transposed)
         float4 a0 = vload4(0, src_addr_a);
         float4 b0 = vload4(0, src_addr_b);
 
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
+
         c00 = fma(a0.s0, b0.s0, c00);
         c01 = fma(a0.s0, b0.s1, c01);
         c02 = fma(a0.s0, b0.s2, c02);
@@ -1130,9 +1145,6 @@
     src_addr.s1 += get_global_id(2) * src1_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    // Address boundary for matrix A
-    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
     // Initialize accumulators
     float acc00 = 0.0f;
     float acc01 = 0.0f;
@@ -1161,72 +1173,162 @@
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
-    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
+    int i = 0;
+    for(; i <= ((int)COLS_A - 4); i += 4)
     {
-        // Load values from matrix A
-        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        // Load values from matrix A and matrix B
+        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        float2 a1 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        float2 a2 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        float2 a3 = vload2(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        // Load values from matrix B
-        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
-        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
+        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0.s0, b0.s0, acc00);
-        acc00 = fma(a0.s1, b1.s0, acc00);
         acc01 = fma(a0.s0, b0.s1, acc01);
-        acc01 = fma(a0.s1, b1.s1, acc01);
         acc02 = fma(a0.s0, b0.s2, acc02);
-        acc02 = fma(a0.s1, b1.s2, acc02);
-        acc03 = fma(a0.s1, b1.s3, acc03);
         acc03 = fma(a0.s0, b0.s3, acc03);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
         acc10 = fma(a1.s0, b0.s0, acc10);
         acc11 = fma(a1.s0, b0.s1, acc11);
         acc12 = fma(a1.s0, b0.s2, acc12);
         acc13 = fma(a1.s0, b0.s3, acc13);
 
-        acc10 = fma(a1.s1, b1.s0, acc10);
-        acc11 = fma(a1.s1, b1.s1, acc11);
-        acc12 = fma(a1.s1, b1.s2, acc12);
-        acc13 = fma(a1.s1, b1.s3, acc13);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
         acc20 = fma(a2.s0, b0.s0, acc20);
         acc21 = fma(a2.s0, b0.s1, acc21);
         acc22 = fma(a2.s0, b0.s2, acc22);
         acc23 = fma(a2.s0, b0.s3, acc23);
 
-        acc20 = fma(a2.s1, b1.s0, acc20);
-        acc21 = fma(a2.s1, b1.s1, acc21);
-        acc22 = fma(a2.s1, b1.s2, acc22);
-        acc23 = fma(a2.s1, b1.s3, acc23);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
         acc30 = fma(a3.s0, b0.s0, acc30);
         acc31 = fma(a3.s0, b0.s1, acc31);
         acc32 = fma(a3.s0, b0.s2, acc32);
         acc33 = fma(a3.s0, b0.s3, acc33);
-
-        acc30 = fma(a3.s1, b1.s0, acc30);
-        acc31 = fma(a3.s1, b1.s1, acc31);
-        acc32 = fma(a3.s1, b1.s2, acc32);
-        acc33 = fma(a3.s1, b1.s3, acc33);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s1, b0.s0, acc00);
+        acc01 = fma(a0.s1, b0.s1, acc01);
+        acc02 = fma(a0.s1, b0.s2, acc02);
+        acc03 = fma(a0.s1, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s1, b0.s0, acc10);
+        acc11 = fma(a1.s1, b0.s1, acc11);
+        acc12 = fma(a1.s1, b0.s2, acc12);
+        acc13 = fma(a1.s1, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s1, b0.s0, acc20);
+        acc21 = fma(a2.s1, b0.s1, acc21);
+        acc22 = fma(a2.s1, b0.s2, acc22);
+        acc23 = fma(a2.s1, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        acc30 = fma(a3.s1, b0.s0, acc30);
+        acc31 = fma(a3.s1, b0.s1, acc31);
+        acc32 = fma(a3.s1, b0.s2, acc32);
+        acc33 = fma(a3.s1, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s2, b0.s0, acc00);
+        acc01 = fma(a0.s2, b0.s1, acc01);
+        acc02 = fma(a0.s2, b0.s2, acc02);
+        acc03 = fma(a0.s2, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s2, b0.s0, acc10);
+        acc11 = fma(a1.s2, b0.s1, acc11);
+        acc12 = fma(a1.s2, b0.s2, acc12);
+        acc13 = fma(a1.s2, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s2, b0.s0, acc20);
+        acc21 = fma(a2.s2, b0.s1, acc21);
+        acc22 = fma(a2.s2, b0.s2, acc22);
+        acc23 = fma(a2.s2, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        acc30 = fma(a3.s2, b0.s0, acc30);
+        acc31 = fma(a3.s2, b0.s1, acc31);
+        acc32 = fma(a3.s2, b0.s2, acc32);
+        acc33 = fma(a3.s2, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        // Load values from matrix B
+        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+
+        // Multiply and accumulate
+        acc00 = fma(a0.s3, b0.s0, acc00);
+        acc01 = fma(a0.s3, b0.s1, acc01);
+        acc02 = fma(a0.s3, b0.s2, acc02);
+        acc03 = fma(a0.s3, b0.s3, acc03);
+
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+
+        acc10 = fma(a1.s3, b0.s0, acc10);
+        acc11 = fma(a1.s3, b0.s1, acc11);
+        acc12 = fma(a1.s3, b0.s2, acc12);
+        acc13 = fma(a1.s3, b0.s3, acc13);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+
+        acc20 = fma(a2.s3, b0.s0, acc20);
+        acc21 = fma(a2.s3, b0.s1, acc21);
+        acc22 = fma(a2.s3, b0.s2, acc22);
+        acc23 = fma(a2.s3, b0.s3, acc23);
+
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
+#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        acc30 = fma(a3.s3, b0.s0, acc30);
+        acc31 = fma(a3.s3, b0.s1, acc31);
+        acc32 = fma(a3.s3, b0.s2, acc32);
+        acc33 = fma(a3.s3, b0.s3, acc33);
+#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += 4 * sizeof(float);
     }
 
-    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
+    for(; i < (int)COLS_A; ++i)
     {
         // Load values from matrix A
-        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        float a0 = *((__global float *)(src0_ptr + src_addr.s0));
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
         float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
@@ -1238,6 +1340,7 @@
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         // Load values from matrix B
         float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0, b0.s0, acc00);
@@ -1262,6 +1365,8 @@
         acc32 = fma(a3, b0.s2, acc32);
         acc33 = fma(a3, b0.s3, acc33);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float);
     }
 
     // Compute destination address
@@ -1375,9 +1480,6 @@
     src_addr.s1 += get_global_id(2) * src1_stride_z;
 #endif // defined(MATRIX_B_DEPTH)
 
-    // Address boundary for the matrix A
-    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(float));
-
     // Initialize accumulators
     float acc00 = 0.0f;
     float acc01 = 0.0f;
@@ -1396,67 +1498,114 @@
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
 
     // A and B src indices get incremented at the same time.
-    for(; src_addr.s0 <= (end_row_vec_a - 4 * (int)sizeof(float)); src_addr += (int2)(4 * sizeof(float), 4 * src1_stride_y))
+    int i = 0;
+    for(; i <= ((int)COLS_A - 8); i += 8)
     {
         // Load values from matrix A
-        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
+        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
 
         // Load values from matrix B
-        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 0 * src1_stride_y));
-        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 1 * src1_stride_y));
-        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 2 * src1_stride_y));
-        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1 + 3 * src1_stride_y));
+        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
+        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0.s0, b0.s0, acc00);
         acc00 = fma(a0.s1, b1.s0, acc00);
         acc00 = fma(a0.s2, b2.s0, acc00);
         acc00 = fma(a0.s3, b3.s0, acc00);
+        acc00 = fma(a0.s4, b4.s0, acc00);
+        acc00 = fma(a0.s5, b5.s0, acc00);
+        acc00 = fma(a0.s6, b6.s0, acc00);
+        acc00 = fma(a0.s7, b7.s0, acc00);
 
         acc01 = fma(a0.s0, b0.s1, acc01);
         acc01 = fma(a0.s1, b1.s1, acc01);
         acc01 = fma(a0.s2, b2.s1, acc01);
         acc01 = fma(a0.s3, b3.s1, acc01);
+        acc01 = fma(a0.s4, b4.s1, acc01);
+        acc01 = fma(a0.s5, b5.s1, acc01);
+        acc01 = fma(a0.s6, b6.s1, acc01);
+        acc01 = fma(a0.s7, b7.s1, acc01);
 
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
-        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
+        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
         acc10 = fma(a0.s0, b0.s0, acc10);
         acc10 = fma(a0.s1, b1.s0, acc10);
         acc10 = fma(a0.s2, b2.s0, acc10);
         acc10 = fma(a0.s3, b3.s0, acc10);
+        acc10 = fma(a0.s4, b4.s0, acc10);
+        acc10 = fma(a0.s5, b5.s0, acc10);
+        acc10 = fma(a0.s6, b6.s0, acc10);
+        acc10 = fma(a0.s7, b7.s0, acc10);
 
         acc11 = fma(a0.s0, b0.s1, acc11);
         acc11 = fma(a0.s1, b1.s1, acc11);
         acc11 = fma(a0.s2, b2.s1, acc11);
         acc11 = fma(a0.s3, b3.s1, acc11);
+        acc11 = fma(a0.s4, b4.s1, acc11);
+        acc11 = fma(a0.s5, b5.s1, acc11);
+        acc11 = fma(a0.s6, b6.s1, acc11);
+        acc11 = fma(a0.s7, b7.s1, acc11);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
-        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
+        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
         acc20 = fma(a0.s0, b0.s0, acc20);
         acc20 = fma(a0.s1, b1.s0, acc20);
         acc20 = fma(a0.s2, b2.s0, acc20);
         acc20 = fma(a0.s3, b3.s0, acc20);
+        acc20 = fma(a0.s4, b4.s0, acc20);
+        acc20 = fma(a0.s5, b5.s0, acc20);
+        acc20 = fma(a0.s6, b6.s0, acc20);
+        acc20 = fma(a0.s7, b7.s0, acc20);
 
         acc21 = fma(a0.s0, b0.s1, acc21);
         acc21 = fma(a0.s1, b1.s1, acc21);
         acc21 = fma(a0.s2, b2.s1, acc21);
         acc21 = fma(a0.s3, b3.s1, acc21);
+        acc21 = fma(a0.s4, b4.s1, acc21);
+        acc21 = fma(a0.s5, b5.s1, acc21);
+        acc21 = fma(a0.s6, b6.s1, acc21);
+        acc21 = fma(a0.s7, b7.s1, acc21);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
-        a0    = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
+        a0    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
         acc30 = fma(a0.s0, b0.s0, acc30);
         acc30 = fma(a0.s1, b1.s0, acc30);
         acc30 = fma(a0.s2, b2.s0, acc30);
         acc30 = fma(a0.s3, b3.s0, acc30);
+        acc30 = fma(a0.s4, b4.s0, acc30);
+        acc30 = fma(a0.s5, b5.s0, acc30);
+        acc30 = fma(a0.s6, b6.s0, acc30);
+        acc30 = fma(a0.s7, b7.s0, acc30);
 
         acc31 = fma(a0.s0, b0.s1, acc31);
         acc31 = fma(a0.s1, b1.s1, acc31);
         acc31 = fma(a0.s2, b2.s1, acc31);
         acc31 = fma(a0.s3, b3.s1, acc31);
+        acc31 = fma(a0.s4, b4.s1, acc31);
+        acc31 = fma(a0.s5, b5.s1, acc31);
+        acc31 = fma(a0.s6, b6.s1, acc31);
+        acc31 = fma(a0.s7, b7.s1, acc31);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float) * 8;
     }
     // float size increment
-    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(4, src1_stride_y))
+    for(; i < (int)COLS_A; ++i)
     {
         // Load values from matrix A
         float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
@@ -1471,6 +1620,7 @@
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
         // Load values from matrix B
         float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
+        src_addr.s1 += src1_stride_y;
 
         // Multiply and accumulate
         acc00 = fma(a0, b0.s0, acc00);
@@ -1487,6 +1637,8 @@
         acc30 = fma(a3, b0.s0, acc30);
         acc31 = fma(a3, b0.s1, acc31);
 #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
+
+        src_addr.s0 += sizeof(float);
     }
 
     // Compute destination address
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 5dd1f00..9b3bf48 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -37,10 +37,8 @@
 
 namespace
 {
-Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output, bool is_interleaved_transposed)
+Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output)
 {
-    const GPUTarget gpu_target = CLScheduler::get().target();
-
     if(is_data_type_quantized_asymmetric(input.data_type()))
     {
         // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
@@ -55,7 +53,7 @@
     }
     else
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&input, &weights, &output, 1.f, is_interleaved_transposed, GEMMReshapeInfo(), gpu_target));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMM::validate(&input, &weights, nullptr, &output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */)));
     }
 
     return Status{};
@@ -75,12 +73,12 @@
 }
 
 CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
-    : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _im2col_output(),
-      _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
+    : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
+      _im2col_output(), _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr)
 {
 }
 
-void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool is_interleaved_transposed)
+void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output)
 {
     if(_is_quantized)
     {
@@ -102,8 +100,7 @@
     else
     {
         // Configure matrix multiply kernel
-        _mm_kernel.set_target(CLScheduler::get().target());
-        _mm_kernel.configure(input, weights, output, 1.f, is_interleaved_transposed);
+        _mm_gemm.configure(input, weights, nullptr, output, 1.f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run */));
     }
 }
 
@@ -122,7 +119,7 @@
     _im2col_kernel.configure(input, &_im2col_output, Size2D(1, 1), PadStrideInfo(1, 1, 0, 0), false);
 
     // Configure matrix multiply kernel
-    configure_mm(&_im2col_output, weights, output, false);
+    configure_mm(&_im2col_output, weights, output);
 
     // Allocate the output tensor for im2col once all the configure methods have been called
     _im2col_output.allocator()->allocate();
@@ -133,7 +130,7 @@
     ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
 
     // Configure matrix multiply kernel
-    configure_mm(input, weights, output, false);
+    configure_mm(input, weights, output);
 }
 
 void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped)
@@ -301,7 +298,7 @@
         ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != weights_to_use->dimension(1));
     }
     // Validate matrix multiply kernel
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output, false));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(*input_to_use, *weights_to_use, *tmp_output));
 
     // Validate output stage for asymmetric quantized types
     if(is_quantized)
@@ -341,7 +338,7 @@
     }
     else
     {
-        CLScheduler::get().enqueue(_mm_kernel, !_accumulate_biases);
+        _mm_gemm.run();
     }
 
     // Accumulate biases if provided