Validate mlgo gemm type selection and fall back to default heuristics

The GEMM kernel type returned by the mlgo heuristics can also be invalid
in both CLGEMM and CLGEMMLowpMatrixMultiplyCore. Fix this by falling back
to the default heuristics in that case, similar to how gemm configs are
handled for now.

Resolves COMPMID-3847

Change-Id: Iae7c1dcd7def04969ad13a4c132873fda8c8a571
Signed-off-by: SiCong Li <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5044
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index c5618f2..ef160d1 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -43,21 +43,31 @@
 {
 namespace auto_heuristics
 {
-CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+GEMMTypeResult select_mlgo_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
 {
-    // Select between mlgo and default heuristics
-    auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    ARM_COMPUTE_UNUSED(reshape_b_only_on_first_run);
+    bool             valid = false;
+    CLGEMMKernelType gemm_type{};
+    const auto       mlgo_heuristics = CLScheduler::get().gemm_heuristics();
     if(mlgo_heuristics != nullptr)
     {
-        auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
-        if(res.first)
-        {
-            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
-            return res.second;
-        }
+        std::tie(valid, gemm_type) = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
     }
-    std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
-    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
+    if(valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm type: %s.", to_string(gemm_type).c_str());
+    }
+    else
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+    }
+    return GEMMTypeResult(valid, gemm_type);
+}
+
+GEMMTypeResult select_default_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
+{
+    std::unique_ptr<ICLGEMMKernelSelection> default_heuristics = CLGEMMKernelSelectionFactory::create(query.gpu_target);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(default_heuristics.get());
 
     CLGEMMKernelSelectionParams params;
     params.m               = query.m;
@@ -67,9 +77,8 @@
     params.is_rhs_constant = reshape_b_only_on_first_run;
     params.data_type       = query.data_type;
 
-    const auto kernel_type = gemm_kernel->select_kernel(params);
-    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
-    return kernel_type;
+    const auto kernel_type = default_heuristics->select_kernel(params);
+    return GEMMTypeResult(true, kernel_type);
 }
 
 GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
@@ -88,7 +97,7 @@
     GEMMLHSMatrixInfo               lhs_info;
     GEMMRHSMatrixInfo               rhs_info;
     mlgo::GEMMConfigReshapedOnlyRHS config{};
-    auto                            mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    const auto                      mlgo_heuristics = CLScheduler::get().gemm_heuristics();
     if(mlgo_heuristics != nullptr)
     {
         std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -123,7 +132,7 @@
     GEMMLHSMatrixInfo        lhs_info;
     GEMMRHSMatrixInfo        rhs_info;
     mlgo::GEMMConfigReshaped config{};
-    auto                     mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    const auto               mlgo_heuristics = CLScheduler::get().gemm_heuristics();
     if(mlgo_heuristics != nullptr)
     {
         std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
@@ -158,7 +167,7 @@
     GEMMLHSMatrixInfo      lhs_info;
     GEMMRHSMatrixInfo      rhs_info;
     mlgo::GEMMConfigNative config{};
-    auto                   mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    const auto             mlgo_heuristics = CLScheduler::get().gemm_heuristics();
     if(mlgo_heuristics != nullptr)
     {
         std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_native(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });