COMPMID-3930: Update CLGEMM heuristic for fp16 on Mali-G76

- Since the GEMM kernel can now work without padding, the heuristic
  needs to be fine-tuned to exploit this feature
- The heuristic affects Mali-G76 FP16 only

Change-Id: Ia430627f02131ad956ce2219b80c83c8e7cabaf2
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4284
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
diff --git a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
index 0a0fc5d..3105db6 100644
--- a/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
+++ b/src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfigurationBifrost.cpp
@@ -151,15 +151,13 @@
     // Get lhs_info/rhs_info in case of OpenCL buffer
     if(m == 1)
     {
-        if((n / 4) >= 2048)
+        if(n <= 204.0)
         {
-            const unsigned int h0 = std::max(n / 4, 1U);
-            std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, h0, false, true, false, true);
+            return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 16, false, true, false, true, false);
         }
         else
         {
-            const unsigned int h0 = std::max(n / 2, 1U);
-            std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, false, true, false, true);
+            return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 32, false, true, false, true, false);
         }
     }
     else
@@ -247,7 +245,6 @@
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> CLGEMMReshapedOnlyRHSKernelConfigurationBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
 {
     ARM_COMPUTE_UNUSED(k);
-    ARM_COMPUTE_UNUSED(b);
 
     if(m == 1)
     {
@@ -255,7 +252,65 @@
     }
     else
     {
-        return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, false, true, false, true);
+        const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
+        const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
+
+        if(workload <= 7449.60f)
+        {
+            if(workload <= 691.60f)
+            {
+                return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 8, false, false, false, false, false);
+            }
+            else
+            {
+                if(workload <= 4155.20f)
+                {
+                    return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+                }
+                else
+                {
+                    return configure_lhs_rhs_info(m, n, 5, 8, 2, 1, 32, false, false, false, false, false);
+                }
+            }
+        }
+        else
+        {
+            if(workload <= 16300.80f)
+            {
+                if(r_mn <= 44.56f)
+                {
+                    GEMMLHSMatrixInfo lhs_info_buf;
+                    GEMMRHSMatrixInfo rhs_info_buf;
+                    GEMMLHSMatrixInfo lhs_info_img;
+                    GEMMRHSMatrixInfo rhs_info_img;
+
+                    std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
+                    std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+
+                    return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                               std::make_pair(lhs_info_buf, rhs_info_buf),
+                                               n, k, b, DataType::F16);
+                }
+                else
+                {
+                    return configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+                }
+            }
+            else
+            {
+                GEMMLHSMatrixInfo lhs_info_buf;
+                GEMMRHSMatrixInfo rhs_info_buf;
+                GEMMLHSMatrixInfo lhs_info_img;
+                GEMMRHSMatrixInfo rhs_info_img;
+
+                std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 2, false, true, false, false, true);
+                std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 5, 2, 8, 1, 16, false, false, false, false, false);
+
+                return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
+                                           std::make_pair(lhs_info_buf, rhs_info_buf),
+                                           n, k, b, DataType::F16);
+            }
+        }
     }
 }