COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1

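Rework the OpenCL reshaped GEMM kernel and the Bifrost kernel configuration so
that the reshaped path can also run in FP16:

- Collapse the per-size ARM_VFMA_<SIZE> macros into a single ARM_VFMA(a, b, c)
  based on the vector overloads of fma() (plain multiply-accumulate on Midgard).
- Unroll the inner K0 steps of the k loop explicitly: each step loads M0 values
  from the reshaped LHS and N0 values from the reshaped RHS and accumulates them
  with ARM_MM_T_NT.
- Add FP16 configurations for Bifrost G7x and G76 in
  CLGEMMReshapedKernelConfigurationBifrost.
- Accept DataType::F16 in CLGEMMMatrixMultiplyReshapedKernel validation.

A minimal sketch (not part of this patch; the wrapper kernel is hypothetical) of
how the simplified ARM_VFMA is expected to behave with half4 data on non-Midgard
targets:

    #pragma OPENCL EXTENSION cl_khr_fp16 : enable

    #define ARM_VFMA(a, b, c) c = fma((a), (b), (c));

    __kernel void vfma_example(__global const half4 *a, __global const half4 *b, __global half4 *out)
    {
        const int gid = get_global_id(0);
        half4     acc = (half4)(0.0h);
        ARM_VFMA(a[gid], b[gid], acc); // expands to: acc = fma(a[gid], b[gid], acc);
        out[gid] = acc;
    }
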
Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1983
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8..c35d160 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@
 #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
 
 #if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
 #else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c)     \
-    ({                          \
-        c = fma((a), (b), (c)); \
-    })
-#define ARM_VFMA_2(a, b, c)                   \
-    ({                                        \
-        (c).s0 = fma((a).s0, (b).s0, (c).s0); \
-        (c).s1 = fma((a).s1, (b).s1, (c).s1); \
-    })
-#define ARM_VFMA_3(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_2(a, b, c);                  \
-        (c).s2 = fma((a).s2, (b).s2, (c).s2); \
-    })
-#define ARM_VFMA_4(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_3(a, b, c);                  \
-        (c).s3 = fma((a).s3, (b).s3, (c).s3); \
-    })
-#define ARM_VFMA_8(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_4(a, b, c);                  \
-        (c).s4 = fma((a).s4, (b).s4, (c).s4); \
-        (c).s5 = fma((a).s5, (b).s5, (c).s5); \
-        (c).s6 = fma((a).s6, (b).s6, (c).s6); \
-        (c).s7 = fma((a).s7, (b).s7, (c).s7); \
-    })
-#define ARM_VFMA_16(a, b, c)                  \
-    ({                                        \
-        ARM_VFMA_8(a, b, c);                  \
-        (c).s8 = fma((a).s8, (b).s8, (c).s8); \
-        (c).s9 = fma((a).s9, (b).s9, (c).s9); \
-        (c).sA = fma((a).sA, (b).sA, (c).sA); \
-        (c).sB = fma((a).sB, (b).sB, (c).sB); \
-        (c).sC = fma((a).sC, (b).sC, (c).sC); \
-        (c).sD = fma((a).sD, (b).sD, (c).sD); \
-        (c).sE = fma((a).sE, (b).sE, (c).sE); \
-        (c).sF = fma((a).sF, (b).sF, (c).sF); \
-    })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
 #endif // GPU_ARCH == GPU_ARCH_MIDGARD
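+// ARM_VFMA(a, b, c) computes c = fma(a, b, c); the fma() builtin is overloaded for
+// all OpenCL vector widths, so no per-size variants are needed. On Midgard the
+// accumulation is expressed as c += a * b instead.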
 
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)         \
-    ({                                                 \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)     \
+    ({                                             \
+        ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
     })
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
     })
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
     })
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
     })
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
     })
 
 // Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
@@ -2172,7 +2130,8 @@
 // K0: 1, 2, 3, 4, 8, 16
 // This macro calls the vector-by-matrix macro K0 times
 // A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
     (M0, N0, TYPE, A, B, C)
 
 /** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@
 #if defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (N0)
 #define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
 #else // defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
 #define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
 #endif // defined(RHS_INTERLEAVE)
 
     const uint x = get_global_id(0);
@@ -2306,28 +2263,160 @@
     // Initialize the accumulators
     REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;
 
-    REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
     REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
 
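+    // Element-typed views of the reshaped LHS/RHS blocks; lhs_addr and rhs_addr are
+    // the byte addresses computed for this work-item above.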
+    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+    __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
     for(int i = 0; i < k; i += K0)
     {
-        // Supported cases (K0, M0):
-        // 1,2  - 2,2  - 3,2  - 4,2  - 5,2  - 6,2  - 7,2  - 8,2
-        // 1,3  - 2,3  - 3,3  - 4,3  - 5,3  - 6,3  - 7,3  - 8,3
-        // 1,4  - 2,4  - 3,4  - 4,4  - 5,4  - 6,4  - 7,4  - 8,4
-        // 1,8  - 2,8  - 3,8 -  4,8  - 5,8  - 6,8  - 7,8  - 8,8
-        // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
-        // Load values from LHS matrix
-        LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        VEC_DATA_TYPE(DATA_TYPE, M0)
+        a0 = VLOAD(M0)(0, lhs);
+        VEC_DATA_TYPE(DATA_TYPE, N0)
+        b0 = VLOAD(N0)(0, rhs);
 
-        // Load values from RHS matrix
-        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
 
-        // Perform the partial matrix multiplication
-        ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
 
-        lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
-        rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
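+        // The K0 inner steps are unrolled explicitly below: each step loads M0 values
+        // from the reshaped LHS and N0 values from the reshaped RHS, accumulates them
+        // with ARM_MM_T_NT and advances both pointers by one step.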
+#if K0 > 1
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 8
+
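+        // Without interleaved reshaping, skip over the other V0-1 LHS blocks /
+        // H0-1 RHS blocks before processing the next K0 chunk.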
+#ifndef LHS_INTERLEAVE
+        lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
+
+#ifndef RHS_INTERLEAVE
+        rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
     }
 
     __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1c..0c2942a 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
 {
-    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
-    ARM_COMPUTE_UNUSED(data_type);
+    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
 
     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 
@@ -51,6 +50,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
     };
 
@@ -58,6 +58,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
     };
 
@@ -85,6 +86,21 @@
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 222a63d..f77ab02 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -63,7 +63,7 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");