COMPMID-2572: Update the heuristic in CLGEMM for FP16

Change-Id: Ia7f248d72bf2690a8a4dae3f2e2afc983109bd6e
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2035
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index 9f3fc3a..5955bac 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -42,9 +42,6 @@
 
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
 {
-    ARM_COMPUTE_ERROR_ON(data_type != DataType::F32 && data_type != DataType::QASYMM8);
-    ARM_COMPUTE_UNUSED(data_type);
-
     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::*)(unsigned int m, unsigned int n, unsigned int k,
                                              unsigned int b);
 
@@ -52,6 +49,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G51 =
     {
         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f32 },
+        { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_u8 }
     };
 
@@ -59,6 +57,7 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G76 =
     {
         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f32 },
+        { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_u8 }
     };
 
@@ -66,17 +65,39 @@
     static std::map<DataType, ConfigurationFunctionExecutorPtr> gemm_configs_G7x =
     {
         { DataType::F32, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f32 },
+        { DataType::F16, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16 },
         { DataType::QASYMM8, &CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8 }
     };
 
     switch(_target)
     {
         case GPUTarget::G76:
-            return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+            if (gemm_configs_G76.find(data_type) != gemm_configs_G76.end())
+            {
+                return (this->*gemm_configs_G76[data_type])(m, n, k, b);
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR("Not supported data type");
+            }
         case GPUTarget::G51:
-            return (this->*gemm_configs_G51[data_type])(m, n, k, b);
+            if (gemm_configs_G51.find(data_type) != gemm_configs_G51.end())
+            {
+                return (this->*gemm_configs_G51[data_type])(m, n, k, b);
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR("Not supported data type");
+            }
         default:
-            return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
+            if (gemm_configs_G7x.find(data_type) != gemm_configs_G7x.end())
+            {
+                return (this->*gemm_configs_G7x[data_type])(m, n, k, b);
+            }
+            else
+            {
+                ARM_COMPUTE_ERROR("Not supported data type");
+            }
     }
 }
 
@@ -89,12 +110,12 @@
     {
         if(n > 2048)
         {
-            const unsigned int h0 = std::max(n / 4, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 4, 1U);
             return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
         }
         else
         {
-            const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 2, 1U);
             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
         }
     }
@@ -111,7 +132,7 @@
 
     if(m == 1)
     {
-        const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+        const unsigned int h0 = std::max(n / 2, 1U);
         return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
     }
     else
@@ -128,7 +149,7 @@
     if(m == 1)
     {
         const unsigned int n0 = n < 1280? 2 : 4;
-        const unsigned int h0 = std::max(n / n0, static_cast<unsigned int>(1));
+        const unsigned int h0 = std::max(n / n0, 1U);
         return configure_lhs_rhs_info(m, n, 1, n0, 4, 1, h0, false, true, false, true);
     }
     else
@@ -137,6 +158,63 @@
     }
 }
 
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k); // k is accepted for signature compatibility but does not influence the choice
+    ARM_COMPUTE_UNUSED(b); // b is likewise ignored by this heuristic
+    // F16 entry of gemm_configs_G7x: the returned LHS/RHS info depends only on whether m == 1 and, in that case, on whether n > 2048.
+    if(m == 1)
+    {
+        if(n > 2048)
+        {
+            const unsigned int h0 = std::max(n / 4, 1U); // n / 4, clamped so the value is never 0
+            return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, false, true, false, true);
+        }
+        else
+        {
+            const unsigned int h0 = std::max(n / 2, 1U); // n / 2, clamped so the value is never 0 (n < 2)
+            return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+        }
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 4, false, true, false, true); // m != 1: fixed arguments, no dependence on the size of n
+    }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k); // k is accepted for signature compatibility but does not influence the choice
+    ARM_COMPUTE_UNUSED(b); // b is likewise ignored by this heuristic
+    // F16 entry of gemm_configs_G76: the returned LHS/RHS info depends only on whether m == 1.
+    if(m == 1)
+    {
+        const unsigned int h0 = std::max(n / 2, 1U); // n / 2, clamped so the value is never 0 (n < 2)
+        return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); // m != 1: fixed arguments, no dependence on the size of n
+    }
+}
+
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G51_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
+{
+    ARM_COMPUTE_UNUSED(k); // k is accepted for signature compatibility but does not influence the choice
+    ARM_COMPUTE_UNUSED(b); // b is likewise ignored by this heuristic
+    // F16 entry of gemm_configs_G51: the returned LHS/RHS info depends only on whether m == 1 and, in that case, on whether n < 1280.
+    if(m == 1)
+    {
+        const unsigned int n0 = n < 1280? 2 : 4; // 2 below the 1280 threshold, 4 from 1280 upwards
+        const unsigned int h0 = std::max(n / n0, 1U); // n / n0, clamped so the value is never 0
+        return configure_lhs_rhs_info(m, n, 1, n0, 8, 1, h0, false, true, false, true);
+    }
+    else
+    {
+        return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true); // m != 1: fixed arguments, no dependence on the size of n
+    }
+}
+
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
@@ -146,12 +224,12 @@
     {
         if(m == 1)
         {
-            const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 2, 1U);
             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
         }
         else
         {
-            const unsigned int h0 = std::max(n / 4, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 4, 1U);
             return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, false, true, false, true);
         }
     }
@@ -159,12 +237,12 @@
     {
         if(m == 1)
         {
-            const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 2, 1U);
             return configure_lhs_rhs_info(m, n, 1, 2, 4, 1, h0, false, true, false, true);
         }
         else
         {
-            const unsigned int h0 = std::max(n / 4, static_cast<unsigned int>(1));
+            const unsigned int h0 = std::max(n / 4, 1U);
             return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, h0, false, true, false, true);
         }
     }
@@ -177,7 +255,7 @@
 
     if(m == 1)
     {
-        const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+        const unsigned int h0 = std::max(n / 2, 1U);
         return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, false, true, false, true);
     }
     else
@@ -193,12 +271,12 @@
 
     if(m == 1)
     {
-        const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+        const unsigned int h0 = std::max(n / 2, 1U);
         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, false, true, false, true);
     }
     else
     {
-        const unsigned int h0 = std::max(n / 2, static_cast<unsigned int>(1));
+        const unsigned int h0 = std::max(n / 2, 1U);
         return configure_lhs_rhs_info(m, n, 4, 2, 16, 1, h0, false, true, false, true);
     }
 }
diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 3d5e148..24b2301 100644
--- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -57,7 +57,7 @@
 {
     ARM_COMPUTE_UNUSED(alpha);
     ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index 2a027d8..8d46014 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -109,7 +109,7 @@
         {
             if((m == 1) || (!reshape_b_only_on_first_run))
             {
-                gemm_type = GEMMType::NATIVE;
+                gemm_type = GEMMType::RESHAPED_ONLY_RHS;
             }
             else
             {