COMPMID-2571: Add support for FP16 in CLGEMMReshaped - part 1

Change-Id: I8adb8850cc5ade49ebc1dbf63401f03d5ecad708
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1983
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
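
For context, the reshaped GEMM path touched by this patch is reached through CLGEMM. A minimal host-side sketch of an FP16 run (shapes, alpha/beta and the default GEMMInfo are illustrative, not taken from this patch):

    #include "arm_compute/runtime/CL/CLScheduler.h"
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLGEMM.h"

    using namespace arm_compute;

    void run_f16_gemm(unsigned int m, unsigned int n, unsigned int k)
    {
        CLScheduler::get().default_init();

        // ACL stores an M x K matrix as TensorShape(K, M): width first.
        CLTensor lhs, rhs, dst;
        lhs.allocator()->init(TensorInfo(TensorShape(k, m), 1, DataType::F16));
        rhs.allocator()->init(TensorInfo(TensorShape(n, k), 1, DataType::F16));
        dst.allocator()->init(TensorInfo(TensorShape(n, m), 1, DataType::F16));

        CLGEMM gemm;
        gemm.configure(&lhs, &rhs, nullptr, &dst, 1.0f, 0.0f); // no bias in this sketch

        lhs.allocator()->allocate();
        rhs.allocator()->allocate();
        dst.allocator()->allocate();
        // ... fill lhs/rhs ...

        gemm.run();
        CLScheduler::get().sync();
    }

With this change, F16 workloads on the listed Bifrost GPUs can now be routed to the reshaped kernels instead of always falling back to the native one.
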
diff --git a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
index a0aae19..3ce2776 100644
--- a/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
+++ b/arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.h
@@ -54,6 +54,8 @@
 private:
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
+    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
     std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 };
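
Each configure_* overload returns the blocking geometry consumed by the reshape and matrix-multiply kernels. The fields involved are shown below with illustrative values (the field set matches how the rest of this patch uses the two structs):

    GEMMLHSMatrixInfo lhs_info;
    lhs_info.m0         = 4;     // rows of the output tile computed per work-item
    lhs_info.k0         = 4;     // accumulation depth handled per inner step
    lhs_info.v0         = 8;     // M0xK0 blocks stored on the same row of the reshaped LHS
    lhs_info.interleave = true;
    lhs_info.transpose  = true;

    GEMMRHSMatrixInfo rhs_info;
    rhs_info.n0         = 4;     // columns of the output tile computed per work-item
    rhs_info.k0         = 4;     // must match lhs_info.k0
    rhs_info.h0         = 2;     // K0xN0 blocks stored on the same row of the reshaped RHS
    rhs_info.interleave = true;
    rhs_info.transpose  = false; // complementary to lhs_info.transpose for this kernel
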
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
index 2a76f44..e6469f0 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h
@@ -51,7 +51,7 @@
     CLGEMMMatrixMultiplyReshapedKernel &operator=(CLGEMMMatrixMultiplyReshapedKernel &&) = default;
     /** Initialise the kernel's input and output.
      *
-     * @param[in]  input0    Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4
+     * @param[in]  input0    Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less than or equal to 4
      * @param[in]  input1    Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less than or equal to 3
      * @param[in]  input2    Input tensor containing the bias matrix. Data type supported: same as @p input0.
      * @param[out] output    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
@@ -74,7 +74,7 @@
                    const GEMMKernelInfo    &gemm_info);
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMMatrixMultiplyReshapedKernel
      *
-     * @param[in] input0    Input tensor containing the LHS reshaped matrix. Data type supported: F32. The number of dimensions for the LHS matrix must be less or equal than 4
+     * @param[in] input0    Input tensor containing the LHS reshaped matrix. Data type supported: F16/F32. The number of dimensions for the LHS matrix must be less than or equal to 4
      * @param[in] input1    Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0. The number of dimensions for the RHS matrix must be less than or equal to 3
      * @param[in] input2    Input tensor info containing the bias matrix. Data type supported: same as @p input0.
      * @param[in] output    Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0
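
As usual in the library, the static validate() mirrors configure(), so an F16 configuration can be checked without allocating anything. A sketch that builds the reshaped infos the same way the test helper removed later in this patch does (geometry values here are hypothetical):

    #include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
    #include "arm_compute/core/utils/misc/ShapeCalculator.h"

    using namespace arm_compute;
    using namespace arm_compute::misc::shape_calculator;

    GEMMLHSMatrixInfo lhs_info;
    lhs_info.m0 = 4; lhs_info.k0 = 4; lhs_info.v0 = 2;
    lhs_info.interleave = true; lhs_info.transpose = true;

    GEMMRHSMatrixInfo rhs_info;
    rhs_info.n0 = 4; rhs_info.k0 = 4; rhs_info.h0 = 2;
    rhs_info.interleave = true; rhs_info.transpose = false;

    GEMMKernelInfo kernel_info;
    kernel_info.m = 24; kernel_info.n = 16; kernel_info.k = 8;

    // M = 24, N = 16, K = 8; original LHS is (K, M), RHS is (N, K).
    const TensorInfo lhs(compute_lhs_reshaped_shape(TensorInfo(TensorShape(8U, 24U), 1, DataType::F16), lhs_info, false), 1, DataType::F16);
    const TensorInfo rhs(compute_rhs_reshaped_shape(TensorInfo(TensorShape(16U, 8U), 1, DataType::F16), rhs_info), 1, DataType::F16);
    const TensorInfo dst(TensorShape(16U, 24U), 1, DataType::F16);

    const Status status = CLGEMMMatrixMultiplyReshapedKernel::validate(&lhs, &rhs, nullptr, &dst,
                                                                       1.0f, 0.0f, lhs_info, rhs_info, kernel_info);
    // status.error_code() != ErrorCode::OK e.g. for an unsupported data type,
    // or for F16 on a device without half-precision support.
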
diff --git a/src/core/CL/cl_kernels/gemm.cl b/src/core/CL/cl_kernels/gemm.cl
index 8e628e8..c35d160 100644
--- a/src/core/CL/cl_kernels/gemm.cl
+++ b/src/core/CL/cl_kernels/gemm.cl
@@ -2041,79 +2041,37 @@
 #define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)
 
 #if GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA(SIZE, a, b, c) c += (a) * (b);
+#define ARM_VFMA(a, b, c) c += (a) * (b);
 #else // GPU_ARCH == GPU_ARCH_MIDGARD
-#define ARM_VFMA_1(a, b, c)     \
-    ({                          \
-        c = fma((a), (b), (c)); \
-    })
-#define ARM_VFMA_2(a, b, c)                   \
-    ({                                        \
-        (c).s0 = fma((a).s0, (b).s0, (c).s0); \
-        (c).s1 = fma((a).s1, (b).s1, (c).s1); \
-    })
-#define ARM_VFMA_3(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_2(a, b, c);                  \
-        (c).s2 = fma((a).s2, (b).s2, (c).s2); \
-    })
-#define ARM_VFMA_4(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_3(a, b, c);                  \
-        (c).s3 = fma((a).s3, (b).s3, (c).s3); \
-    })
-#define ARM_VFMA_8(a, b, c)                   \
-    ({                                        \
-        ARM_VFMA_4(a, b, c);                  \
-        (c).s4 = fma((a).s4, (b).s4, (c).s4); \
-        (c).s5 = fma((a).s5, (b).s5, (c).s5); \
-        (c).s6 = fma((a).s6, (b).s6, (c).s6); \
-        (c).s7 = fma((a).s7, (b).s7, (c).s7); \
-    })
-#define ARM_VFMA_16(a, b, c)                  \
-    ({                                        \
-        ARM_VFMA_8(a, b, c);                  \
-        (c).s8 = fma((a).s8, (b).s8, (c).s8); \
-        (c).s9 = fma((a).s9, (b).s9, (c).s9); \
-        (c).sA = fma((a).sA, (b).sA, (c).sA); \
-        (c).sB = fma((a).sB, (b).sB, (c).sB); \
-        (c).sC = fma((a).sC, (b).sC, (c).sC); \
-        (c).sD = fma((a).sD, (b).sD, (c).sD); \
-        (c).sE = fma((a).sE, (b).sE, (c).sE); \
-        (c).sF = fma((a).sF, (b).sF, (c).sF); \
-    })
-
-// Factory macro for the vector FMA
-#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))
-
+#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
 #endif // GPU_ARCH == GPU_ARCH_MIDGARD
 
-#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)         \
-    ({                                                 \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
+#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)     \
+    ({                                             \
+        ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
     })
-#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
+#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
     })
-#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
+#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
     })
-#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
+#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
     })
-#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
-    ({                                                    \
-        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
-        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
+#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)        \
+    ({                                                \
+        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);       \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
+        ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
     })
 
 // Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
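
The per-size ARM_VFMA_1/2/3/4/8/16 ladder was redundant: OpenCL C defines the fma() builtin for vector types component-wise, so the single `c = fma(a, b, c)` above covers every supported width and the SIZE argument disappears from the macro (on Midgard the plain multiply-add form is kept instead). For reference, a plain C++ analogue of what the vector builtin computes per lane (illustrative only, not library code):

    #include <array>
    #include <cmath>

    // Stand-in for an OpenCL vector type such as half4 or float8.
    template <typename T, std::size_t N>
    using vec = std::array<T, N>;

    // OpenCL's vector fma(a, b, c): one fused multiply-add per lane, single rounding.
    template <typename T, std::size_t N>
    vec<T, N> vfma(const vec<T, N> &a, const vec<T, N> &b, vec<T, N> c)
    {
        for(std::size_t i = 0; i < N; ++i)
        {
            c[i] = std::fma(a[i], b[i], c[i]);
        }
        return c;
    }
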
@@ -2172,7 +2130,8 @@
 // K0: 1, 2, 3, 4, 8, 16
 // This macro calls the vector-by-matrix macro K0 times
 // A, B and C are matrices
-#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) CONCAT(ARM_MM_T_NT_M0xN0x, K0) \
+#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
+    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
     (M0, N0, TYPE, A, B, C)
 
 /** This OpenCL kernel computes the matrix multiplication between 2 matrices.
@@ -2272,11 +2231,9 @@
 #if defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (N0)
 #define RHS_STEP_X ((N0) * (H0))
-#define RHS_STEP_LOOP (1)
 #else // defined(RHS_INTERLEAVE)
 #define RHS_OFFSET_X (RHS_BLOCK_SIZE)
 #define RHS_STEP_X (N0)
-#define RHS_STEP_LOOP (H0)
 #endif // defined(RHS_INTERLEAVE)
 
     const uint x = get_global_id(0);
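
RHS_STEP_LOOP (and its LHS counterpart) can go because the rewritten loop below advances the typed lhs/rhs pointers by *_STEP_X after every single k-step, then skips the other threads' blocks once per iteration in the non-interleaved case, instead of applying one combined K0 * STEP_X * STEP_LOOP byte offset per K0-block. How the RHS stepping follows from the blocking, restated as a hypothetical helper (mirrors the RHS_OFFSET_X / RHS_STEP_X macros above):

    struct RhsStep
    {
        unsigned int offset_x; // distance from this thread's block to the data it reads
        unsigned int step_x;   // pointer advance per k-step
    };

    RhsStep rhs_step(unsigned int n0, unsigned int k0, unsigned int h0, bool interleave)
    {
        const unsigned int rhs_block_size = n0 * k0;
        return interleave ? RhsStep{ n0, n0 * h0 }         // H0 blocks interleaved row by row
                          : RhsStep{ rhs_block_size, n0 }; // blocks stored back to back
    }
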
@@ -2306,28 +2263,160 @@
     // Initialize the accumulators
     REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;
 
-    REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
     REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
 
+    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
+    __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
+
     for(int i = 0; i < k; i += K0)
     {
-        // Supported cases (K0, M0):
-        // 1,2  - 2,2  - 3,2  - 4,2  - 5,2  - 6,2  - 7,2  - 8,2
-        // 1,3  - 2,3  - 3,3  - 4,3  - 5,3  - 6,3  - 7,3  - 8,3
-        // 1,4  - 2,4  - 3,4  - 4,4  - 5,4  - 6,4  - 7,4  - 8,4
-        // 1,8  - 2,8  - 3,8 -  4,8  - 5,8  - 6,8  - 7,8  - 8,8
-        // 1,16 - 2,16 - 3,16 - 4,16 - 5,16 - 6,16 - 7,16 - 8,16
-        // Load values from LHS matrix
-        LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        VEC_DATA_TYPE(DATA_TYPE, M0)
+        a0 = VLOAD(M0)(0, lhs);
+        VEC_DATA_TYPE(DATA_TYPE, N0)
+        b0 = VLOAD(N0)(0, rhs);
 
-        // Load values from RHS matrix
-        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
 
-        // Perform the partial matrix multiplication
-        ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
 
-        lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
-        rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
+#if K0 > 1
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 1
+
+#if K0 > 2
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 2
+
+#if K0 > 3
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 3
+
+#if K0 > 4
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 4
+
+#if K0 > 8
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+
+        a0 = VLOAD(M0)(0, lhs);
+        b0 = VLOAD(N0)(0, rhs);
+
+        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
+
+        lhs += LHS_STEP_X;
+        rhs += RHS_STEP_X;
+#endif // K0 > 8
+
+#ifndef LHS_INTERLEAVE
+        lhs += (M0 * K0 * (V0 - 1));
+#endif // LHS_INTERLEAVE
+
+#ifndef RHS_INTERLEAVE
+        rhs += (N0 * K0 * (H0 - 1));
+#endif // RHS_INTERLEAVE
     }
 
     __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
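
The unrolled body above replaces the LOAD_BLOCK / ARM_MM_T_NT(K0) pair with K0 explicit rank-1 updates: each step loads one column of the transposed LHS block (M0 values) and one row of the RHS block (N0 values), accumulates their outer product into the M0xN0 tile, and advances both pointers. A scalar C++ model of that decomposition (illustrative; the kernel vectorizes the n-loop into a single vector fma and unrolls the k-loop with the #if K0 > n guards):

    #include <cmath>

    // c[m][n] += sum_k a[k][m] * b[k][n], computed as K0 rank-1 updates.
    template <int M0, int N0, int K0, typename T>
    void block_mm_t_nt(const T *a /* K0 x M0, LHS stored transposed */,
                       const T *b /* K0 x N0 */,
                       T (&c)[M0][N0])
    {
        for(int k = 0; k < K0; ++k) // unrolled in the kernel
        {
            for(int m = 0; m < M0; ++m)
            {
                for(int n = 0; n < N0; ++n)
                {
                    c[m][n] = std::fma(a[k * M0 + m], b[k * N0 + n], c[m][n]);
                }
            }
        }
    }
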
diff --git a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
index b791c1c..0c2942a 100644
--- a/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfigurationBifrost.cpp
@@ -42,8 +42,7 @@
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
 {
-    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
-    ARM_COMPUTE_UNUSED(data_type);
+    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::F16 && data_type != DataType::QASYMM8);
 
     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
 
@@ -51,6 +50,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8 }
     };
 
@@ -58,6 +58,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
     {
         { DataType::F32, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f32 },
+        { DataType::F16, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8 }
     };
 
@@ -85,6 +86,21 @@
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
@@ -129,6 +145,21 @@
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k);
+    ARM_COMPUTE_UNUSED(b);
+
+    if(n <= 4)
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 8, 2, true, true, true, false);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 8, true, true, true, false);
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedKernelConfigurationBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
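
Both F16 heuristics branch only on n (k and b are accepted for interface uniformity), and on G76 both branches appear to pick a 4x4 output tile per work-item, differing in the k0/v0/h0 blocking; the mapping of the literal arguments to fields is inferred from the existing *_f32/_u8 overloads in this file, so treat it as an assumption. With an M0 x N0 tile per thread, the dispatch size follows directly (illustrative helper):

    // One work-item per M0 x N0 tile of the M x N output.
    unsigned int num_work_items(unsigned int m, unsigned int n, unsigned int m0, unsigned int n0)
    {
        const auto ceil_div = [](unsigned int a, unsigned int b) { return (a + b - 1) / b; };
        return ceil_div(m, m0) * ceil_div(n, n0);
    }
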
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
index 222a63d..f77ab02 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp
@@ -63,7 +63,7 @@
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
     ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
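
The F16 gate itself is unchanged: validate() still carries ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(input0), so the new data type is rejected on devices without half-precision support. A host-side pre-check along the same lines (helper names from CLHelpers.h; availability assumed):

    #include "arm_compute/core/CL/CLHelpers.h"
    #include "arm_compute/core/CL/CLKernelLibrary.h"

    bool device_can_run_f16()
    {
        const auto &device = arm_compute::CLKernelLibrary::get().get_device();
        return arm_compute::device_supports_extension(device, "cl_khr_fp16");
    }
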
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index e78395f..762b001 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -65,37 +65,53 @@
 {
     GEMMType gemm_type = GEMMType::RESHAPED_V1;
 
-    if(gpu_target_is_in(gpu_target, GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72, GPUTarget::G76))
+    if(gpu_target_is_in(gpu_target, GPUTarget::G51, GPUTarget::G51BIG, GPUTarget::G51LIT,
+                        GPUTarget::G52, GPUTarget::G52LIT, GPUTarget::G71, GPUTarget::G72,
+                        GPUTarget::G76, GPUTarget::G77))
     {
-        if((m > 1) && (n < 16))
+        if(data_type == DataType::F32)
         {
-            gemm_type = GEMMType::RESHAPED_V1;
-        }
-        else if((m == 1) && (data_type == DataType::F32))
-        {
-            gemm_type = GEMMType::RESHAPED_ONLY_RHS;
-        }
-        else
-        {
-            // COMPMID-852
-            if((k > 256) && (m > 4) && is_data_type_float(data_type) && reshape_b_only_on_first_run)
+            if((m > 1) && (n < 16))
             {
-                constexpr float alpha = 3.2f;
-                constexpr float fact0 = 1.51f;
-                constexpr float fact1 = 1.66f;
-                constexpr float ops   = 12.0f;
-                const float     scale = k > 1024 ? 1.07f : 1.0f;
-                gemm_type             = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+                gemm_type = GEMMType::RESHAPED_V1;
+            }
+            else if(m == 1)
+            {
+                gemm_type = GEMMType::RESHAPED_ONLY_RHS;
             }
             else
             {
+                // COMPMID-852
+                if((k > 256) && (m > 4) && reshape_b_only_on_first_run)
+                {
+                    constexpr float alpha = 3.2f;
+                    constexpr float fact0 = 1.51f;
+                    constexpr float fact1 = 1.66f;
+                    constexpr float ops   = 12.0f;
+                    const float     scale = k > 1024 ? 1.07f : 1.0f;
+                    gemm_type             = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? GEMMType::RESHAPED_V1 : GEMMType::NATIVE;
+                }
+                else
+                {
+                    gemm_type = GEMMType::NATIVE;
+                }
+            }
+
+            const auto workload = static_cast<float>((m * n) / 20.0f);
+
+            gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
+        }
+        else
+        {
+            if((m == 1) || (!reshape_b_only_on_first_run))
+            {
                 gemm_type = GEMMType::NATIVE;
             }
+            else
+            {
+                gemm_type = GEMMType::RESHAPED_V2;
+            }
         }
-
-        const auto workload = static_cast<float>((m * n) / 20.0f);
-
-        gemm_type = ((workload > 1600.0f) && (gemm_type == GEMMType::RESHAPED_V1) && (data_type == DataType::F32)) ? GEMMType::RESHAPED_V2 : gemm_type;
     }
     else
     {
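
Net effect of the restructured heuristic for the listed GPUs: the F32 path keeps the previous RESHAPED_V1 / RESHAPED_ONLY_RHS / NATIVE selection plus the workload-based promotion to RESHAPED_V2, while every other data type reaching this branch (F16 in particular) now uses the new rule. That rule condenses to the following (illustrative restatement, not library code):

    enum class GEMMType { NATIVE, RESHAPED_V1, RESHAPED_V2, RESHAPED_ONLY_RHS };

    // F16 on Bifrost: the reshaped kernels only pay off when the reshaped RHS is
    // reused across runs and there is more than one output row to amortise the
    // LHS reshape.
    GEMMType select_gemm_type_f16(unsigned int m, bool reshape_b_only_on_first_run)
    {
        return ((m == 1) || !reshape_b_only_on_first_run) ? GEMMType::NATIVE
                                                          : GEMMType::RESHAPED_V2;
    }
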
diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
index ba218f7..99f5ffe 100644
--- a/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
+++ b/tests/validation/CL/GEMMMatrixMultiplyReshaped.cpp
@@ -71,6 +71,9 @@
 RelativeTolerance<float> rel_tolerance_f32(0.001f);
 constexpr float          abs_tolerance_f32(0.0001f);
 
+RelativeTolerance<float> rel_tolerance_f16(0.001f);
+constexpr float          abs_tolerance_f16(0.01f);
+
 /** Alpha values to test - Precommit */
 const auto a_values = framework::dataset::make("alpha", {1.0f, -0.75f} );
 
@@ -103,7 +106,7 @@
 });
 
 /** M0 values to test - Precommit */
-const auto m0_values_precommit = framework::dataset::make("M0", {4, 8});
+const auto m0_values_precommit = framework::dataset::make("M0", { 4 });
 
 /** N0 values to test - Precommit */
 const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
@@ -143,94 +146,12 @@
 
 /** LHS transposed values */
 const auto lhs_transpose_values = framework::dataset::make("lhs_transpose", { false, true } );
-
-/** Configuration test */
-void validate_configuration(unsigned int m_value, unsigned int n_value, unsigned int k_value, unsigned int b_value, unsigned int m0_value, unsigned int n0_value, unsigned int k0_value, unsigned int v0_value, unsigned int h0_value, bool i_value_lhs, bool i_value_rhs, bool broadcast_bias, bool lhs_transpose, DataType data_type, const ActivationLayerInfo &act_info)
-{
-    const unsigned int M = m_value;
-    const unsigned int N = n_value;
-    const unsigned int K = k_value;
-
-    GEMMLHSMatrixInfo lhs_info;
-    lhs_info.m0         = m0_value;
-    lhs_info.k0         = k0_value;
-    lhs_info.v0         = v0_value;
-    lhs_info.interleave = i_value_lhs;
-    lhs_info.transpose  = lhs_transpose;
-
-    GEMMRHSMatrixInfo rhs_info;
-    rhs_info.n0         = n0_value;
-    rhs_info.k0         = k0_value;
-    rhs_info.h0         = h0_value;
-    rhs_info.interleave = i_value_rhs;
-    rhs_info.transpose  = !lhs_transpose;
-
-    GEMMKernelInfo kernel_info;
-    kernel_info.m                       = M;
-    kernel_info.n                       = N;
-    kernel_info.k                       = K;
-    kernel_info.depth_output_gemm3d     = 0;
-    kernel_info.reinterpret_input_as_3d = false;
-    kernel_info.broadcast_bias          = broadcast_bias;
-    kernel_info.activation_info         = act_info;
-
-    const TensorShape lhs_shape(K, M, b_value);
-    const TensorShape lhs_shape_reshaped = compute_lhs_reshaped_shape(TensorInfo(lhs_shape, 1, data_type),
-                                                                      lhs_info,
-                                                                      false);
-
-    const TensorShape rhs_shape(N, K, b_value);
-    const TensorShape rhs_shape_reshaped = compute_rhs_reshaped_shape(TensorInfo(rhs_shape, 1, data_type),
-                                                                      rhs_info);
-
-    const TensorShape dst_shape = compute_mm_shape(TensorInfo(lhs_shape_reshaped, 1, data_type),
-                                                   TensorInfo(rhs_shape_reshaped, 1, data_type),
-                                                   kernel_info);
-
-    const TensorShape bias_shape(N,
-                                 broadcast_bias? 1 : M,
-                                 broadcast_bias? 1 : b_value);
-
-    // Create tensors
-    CLTensor lhs_reshaped = create_tensor<CLTensor>(lhs_shape_reshaped, data_type);
-    CLTensor rhs_reshaped = create_tensor<CLTensor>(rhs_shape_reshaped, data_type);
-    CLTensor bias         = create_tensor<CLTensor>(bias_shape, data_type);
-    CLTensor dst          = create_tensor<CLTensor>(dst_shape, data_type);
-
-    ARM_COMPUTE_EXPECT(lhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(rhs_reshaped.info()->is_resizable(), framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
-    ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
-
-    // Create and configure function
-    CLGEMMMatrixMultiplyReshaped gemm;
-    gemm.configure(&lhs_reshaped, &rhs_reshaped, &bias, &dst, 1.0f, 1.0f, lhs_info, rhs_info, kernel_info);
-}
 } // namespace
 
 TEST_SUITE(CL)
 TEST_SUITE(GEMMMatrixMultiplyReshaped)
 TEST_SUITE(Float)
 TEST_SUITE(FP32)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
-                                                                   m_values,
-                                                                   n_values),
-                                                                   k_values),
-                                                                   framework::dataset::make("batch_size", 1)),
-                                                                   m0_values_precommit),
-                                                                   n0_values_precommit),
-                                                                   k0_values_precommit),
-                                                                   v0_values_precommit),
-                                                                   h0_values_precommit),
-                                                                   i_values_lhs),
-                                                                   i_values_rhs),
-                                                                   broadcast_bias_values),
-                                                                   lhs_transpose_values),
-                                                                   act_values),
-m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, act_value)
-{
-    validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, v0_value, h0_value, i_value_lhs, i_value_rhs, broadcast_bias, lhs_transpose, DataType::F32, act_value);
-}
 
 FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<float>, framework::DatasetMode::ALL,
                 combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
@@ -328,6 +249,105 @@
     validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32);
 }
 TEST_SUITE_END() // FP32
+
+TEST_SUITE(FP16)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMMatrixMultiplyReshapedFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_values,
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   broadcast_bias_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::ALL,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_precommit),
+                                                                   n0_values_precommit),
+                                                                   k0_values_precommit),
+                                                                   v0_values_precommit),
+                                                                   h0_values_precommit),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMMatrixMultiplyReshaped3DFixture<half>, framework::DatasetMode::NIGHTLY,
+                combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+                                                                   m_w_values,
+                                                                   m_h_values),
+                                                                   n_values),
+                                                                   k_values),
+                                                                   b_values),
+                                                                   m0_values_nightly),
+                                                                   n0_values_nightly),
+                                                                   k0_values_nightly),
+                                                                   v0_values_nightly),
+                                                                   h0_values_nightly),
+                                                                   i_values_lhs),
+                                                                   i_values_rhs),
+                                                                   framework::dataset::make("DataType", DataType::F16)),
+                                                                   a_values),
+                                                                   beta_values),
+                                                                   lhs_transpose_values),
+                                                                   act_values))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16);
+}
+TEST_SUITE_END() // FP16
 TEST_SUITE_END() // Float
 TEST_SUITE_END() // GEMMMatrixMultiplyReshaped
 TEST_SUITE_END() // CL
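
A note on the FP16 tolerances introduced above: half precision carries 10 fraction bits (machine epsilon 2^-10, about 9.8e-4), and the error of a K-deep accumulation grows roughly linearly with K, so an absolute tolerance on the order of 0.01 for O(1) values is within the expected range. A quick back-of-the-envelope check (illustrative):

    #include <cstdio>

    int main()
    {
        const double eps_f16   = 1.0 / 1024.0;      // 2^-10
        const int    depths[]  = { 4, 8, 16, 32 };  // accumulation depths, chosen for illustration
        for(int k : depths)
        {
            std::printf("K = %2d -> worst-case ~%.4f absolute error on O(1) data\n", k, k * eps_f16);
        }
        return 0;
    }
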