COMPMID-922 - CLGEMM FP16 optimizations - part2

This patch improves GEMM FP16 performance by ~30% when the reshape is required.
The results have been reported on the following Confluence page:
https://confluence.arm.com/display/MLENG/GEMM+FP16+performance%3A+ACL+18.05
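
For context, the new gemm_mm_interleaved_transposed_f16_bifrost kernel (see gemm.cl below) computes a 4x8 FP16 output block per work-item with half8 fma, after matrix A has been interleaved 4x4 and matrix B transposed 1x8. A minimal sketch of that accumulation pattern, with illustrative names (a_ptr, b_ptr, cols_mtx_b) and without the 4x unrolling and MULT_* reshape factors used by the real kernel:

    // Illustrative sketch only, simplified from the kernel added below.
    // a_ptr points into the 4x4-interleaved reshape of A (half),
    // b_ptr points into the 1x8-transposed reshape of B (half),
    // cols_mtx_b is the number of 8-wide column blocks of the reshaped B.
    half8 c00 = 0.0f, c10 = 0.0f, c20 = 0.0f, c30 = 0.0f;
    for(int k = 0; k < cols_mtx_b; ++k)
    {
        half4 a0 = vload4(0, a_ptr); // 4 values of A, one per output row
        half8 b0 = vload8(0, b_ptr); // 8 values of B, one per output column
        a_ptr += 4;
        b_ptr += 8;
        c00 = fma((half8)a0.s0, b0, c00); // rank-1 update of the 4x8 block
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }
    // c00..c30 are then scaled by ALPHA (if defined) and stored as one 4x8 half block of dst.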

Change-Id: I8233095a7e9ab06f1f915782a25dd41653b49140
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128254
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/CLKernelLibrary.cpp b/src/core/CL/CLKernelLibrary.cpp
index 7e3eebc..f1be935 100644
--- a/src/core/CL/CLKernelLibrary.cpp
+++ b/src/core/CL/CLKernelLibrary.cpp
@@ -230,7 +230,8 @@
     { "gemm_mv", "gemv.cl" },
     { "gemm_mv_quantized", "gemv.cl" },
     { "gemm_mm_interleaved_transposed_f16", "gemm.cl" },
-    { "gemm_mm_interleaved_transposed_f32_midgard", "gemm.cl" },
+    { "gemm_mm_interleaved_transposed_f16_bifrost", "gemm.cl" },
+    { "gemm_mm_interleaved_transposed_f32", "gemm.cl" },
     { "gemm_mm_interleaved_transposed_f32_bifrost", "gemm.cl" },
     { "gemm_mm_interleaved_transposed_qs8", "gemm.cl" },
     { "gemm_mm_interleaved_transposed_qs16", "gemm.cl" },
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 381130e..7215f58 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -184,12 +184,12 @@
  * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
  * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
  */
-__kernel void gemm_mm_interleaved_transposed_f32_midgard(IMAGE_DECLARATION(src0),
-                                                         IMAGE_DECLARATION(src1),
-                                                         IMAGE_DECLARATION(dst),
-                                                         uint src0_stride_z,
-                                                         uint src1_stride_z,
-                                                         uint dst_stride_z)
+__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
+                                                 IMAGE_DECLARATION(src1),
+                                                 IMAGE_DECLARATION(dst),
+                                                 uint src0_stride_z,
+                                                 uint src1_stride_z,
+                                                 uint dst_stride_z)
 {
     int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
     int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
@@ -670,6 +670,215 @@
     vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
     vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
 }
+
+/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A (src0) and matrix B (src1).
+ *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication.
+ *
+ * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
+ * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
+ * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
+ * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
+ *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
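+ * @note Example of build options, with illustrative values only: -DCOLS_B=256 -DMULT_TRANSPOSE1XW_WIDTH=2 -DMULT_INTERLEAVE4X4_HEIGHT=2 -DALPHA=1.0f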
+ *
+ * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
+ * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
+ * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
+ * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
+ * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
+ * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
+ * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
+ * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
+ */
+__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
+                                                         IMAGE_DECLARATION(src1),
+                                                         IMAGE_DECLARATION(dst),
+                                                         uint src0_stride_z,
+                                                         uint src1_stride_z,
+                                                         uint dst_stride_z)
+{
+    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
+    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
+    int z = get_global_id(2);
+
+    // Offset
+    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
+    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
+
+    // src_addr_a = address of matrix A
+    // src_addr_b = address of matrix B
+    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
+    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
+
+#if defined(MATRIX_B_DEPTH)
+    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
+    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
+#else  // defined(MATRIX_B_DEPTH)
+    src1_addr_in_bytes += z * src1_stride_z;
+#endif // defined(MATRIX_B_DEPTH)
+
+    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
+    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
+
+    // Compute end row address for matrix B
+    __global half *src_end_addr_b = src_addr_b + COLS_B;
+
+    src_addr_a += offset_row_a;
+    src_addr_b += offset_row_b;
+
+    // Reset accumulators
+    half8 c00 = 0.0f;
+    half8 c10 = 0.0f;
+    half8 c20 = 0.0f;
+    half8 c30 = 0.0f;
+
+#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
+
+    int i = 0;
+    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
+    {
+#if MULT_INTERLEAVE4X4_HEIGHT == 1
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        half8 a0 = vload8(0, src_addr_a);
+        half8 b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+
+        // Load values from matrix B (transposed)
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s4, b0, c00);
+        c10 = fma((half8)a0.s5, b0, c10);
+        c20 = fma((half8)a0.s6, b0, c20);
+        c30 = fma((half8)a0.s7, b0, c30);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload8(0, src_addr_a);
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+
+        // Load values from matrix B (transposed)
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s4, b0, c00);
+        c10 = fma((half8)a0.s5, b0, c10);
+        c20 = fma((half8)a0.s6, b0, c20);
+        c30 = fma((half8)a0.s7, b0, c30);
+#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        half4 a0 = vload4(0, src_addr_a);
+        half8 b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload4(0, src_addr_a);
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload4(0, src_addr_a);
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        a0 = vload4(0, src_addr_a);
+        b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
+    }
+
+    for(; i < (int)(COLS_MTX_B); ++i)
+    {
+        // Load values from matrix A (interleaved) and matrix B (transposed)
+        half4 a0 = vload4(0, src_addr_a);
+        half8 b0 = vload8(0, src_addr_b);
+
+        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
+        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
+
+        c00 = fma((half8)a0.s0, b0, c00);
+        c10 = fma((half8)a0.s1, b0, c10);
+        c20 = fma((half8)a0.s2, b0, c20);
+        c30 = fma((half8)a0.s3, b0, c30);
+    }
+
+    // Compute destination address
+    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
+
+#if defined(ALPHA)
+    // Multiply by the weight of matrix product
+    c00 = c00 * (half8)ALPHA;
+    c10 = c10 * (half8)ALPHA;
+    c20 = c20 * (half8)ALPHA;
+    c30 = c30 * (half8)ALPHA;
+#endif // defined(ALPHA)
+
+    // Compute dst address
+    __global uchar *dst_addr = offset(&dst, 0, 0);
+
+    // Add offset for batched GEMM
+    dst_addr += z * dst_stride_z;
+
+    // Store 4x8 block
+    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
+    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
+    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
+    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
+}
 #endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
 
 #if defined(FIXED_POINT_POSITION)
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
index 2761247..674937e 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp
@@ -265,6 +265,8 @@
     // Do not slide matrix B if _slide_matrix_b = false
     build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2)));
 
+    const bool is_bifrost = get_arch_from_target(gpu_target) == GPUTarget::BIFROST;
+
     std::string kernel_name;
     if(is_interleaved_transposed)
     {
@@ -275,10 +277,9 @@
         build_opts.add_option("-DMULT_TRANSPOSE1XW_WIDTH=" + support::cpp11::to_string(mult_transpose1xW_width));
         build_opts.add_option("-DMULT_INTERLEAVE4X4_HEIGHT=" + support::cpp11::to_string(mult_interleave4x4_height));
 
-        if(data_type == DataType::F32)
+        if(is_data_type_float(data_type) && is_bifrost)
         {
-            GPUTarget arch_target = get_arch_from_target(gpu_target);
-            kernel_name           = "gemm_mm_interleaved_transposed_f32_" + string_from_target(arch_target);
+            kernel_name = "gemm_mm_interleaved_transposed_" + lower_string(string_from_data_type(data_type)) + "_bifrost";
         }
         else
         {
@@ -291,7 +292,7 @@
         build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(data_type));
 
         // Create kernels according to the architecture, data type and input size.
-        if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX) && is_data_type_float(data_type))
+        if(is_data_type_float(data_type) && is_bifrost)
         {
             kernel_name = "gemm_mm_floating_point";
 
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e735adb..1ee51a0 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -32,6 +32,7 @@
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/ITensorAllocator.h"
@@ -47,7 +48,7 @@
     if(gpu_target_is_in(gpu_target, GPUTarget::G71, GPUTarget::G72, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT, GPUTarget::TNOX))
     {
         // COMPMID-852
-        if(k > 256 && m > 4 && data_type == DataType::F32 && reshape_b_only_on_first_run)
+        if(k > 256 && m > 4 && is_data_type_float(data_type) && reshape_b_only_on_first_run)
         {
             const float scale = k < 1024 ? 2.0f : 2.5f;
             flag              = (scale * n) > ((1.66f * n) + 38.4f);