Integrate MLGO into CLGEMM for reshaped kernel

Resolves COMPMID-3845
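
The reshaped kernel now goes through the same mlgo-first selection flow
as the reshaped only rhs kernel: query the MLGO heuristics, validate the
returned lhs/rhs info against the reshape and matrix multiply kernels,
and fall back to the default heuristics if the query or the validation
fails. A simplified sketch of the selection helper (condensed from
auto_select_gemm_config_reshaped in the diff below; arguments and
logging elided):

    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
    if(config && validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, /* ... */))
    {
        return { config.lhs_info, config.rhs_info }; // mlgo config
    }
    config = select_default_gemm_config_reshaped(query);
    return { config.lhs_info, config.rhs_info };     // default heuristics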

Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: I878ea6dc076177095816a75f9bc951326fd095b3
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5031
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp
index dcb9cb2..a0aaabf 100644
--- a/src/runtime/CL/functions/CLGEMM.cpp
+++ b/src/runtime/CL/functions/CLGEMM.cpp
@@ -37,9 +37,6 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
 #include "arm_compute/runtime/ITensorAllocator.h"
-#include "src/core/CL/ICLGEMMKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
 #include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
 #include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
@@ -102,6 +99,7 @@
 
 namespace
 {
+// Validate lhs_info and rhs_info for reshaped only rhs kernel
 inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                                     const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
 {
@@ -129,6 +127,7 @@
     return true;
 }
 
+// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
 inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
                                                                                                  const ITensorInfo *b,
                                                                                                  const ITensorInfo *c, const ITensorInfo *output)
@@ -147,6 +146,55 @@
     return { config.lhs_info, config.rhs_info };
 }
 
+// Validate lhs_info and rhs_info for reshaped kernel
+inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
+                                           const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
+{
+    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
+    TensorInfo tmp_a_info{};
+    TensorInfo tmp_b_info{};
+
+    // Validate reshape LHS kernel
+    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
+    if(!bool(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
+    {
+        return false;
+    }
+
+    // Validate reshape RHS kernel
+    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
+    if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
+    {
+        return false;
+    }
+    // Validate mm kernel
+    gemm_kernel_info.lhs_info = lhs_info;
+    gemm_kernel_info.rhs_info = rhs_info;
+    if(!bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
+    {
+        return false;
+    }
+    return true;
+}
+
+// Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
+inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
+                                                                                        const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
+{
+    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
+    if(config)
+    {
+        if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+            return { config.lhs_info, config.rhs_info };
+        }
+    }
+    config = select_default_gemm_config_reshaped(query);
+    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+    return { config.lhs_info, config.rhs_info };
+}
+
 } // namespace
 
 CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
@@ -311,10 +359,8 @@
     GEMMRHSMatrixInfo rhs_info{};
 
     // Pick up the GEMM configuration
-    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
-    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
-    // Configure lhs_info and rhs_info
-    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(),
+                                                                    c == nullptr ? nullptr : c->info(), output->info(), gemm_info.reinterpret_input_as_3d());
 
     _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
 
@@ -518,11 +564,10 @@
     GEMMRHSMatrixInfo rhs_info;
 
     // Pick up the GEMM configuration
-    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
-
-    // Configure lhs_info and rhs_info
-    std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
+    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
+    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
+    lhs_info               = gemm_config.lhs_info;
+    rhs_info               = gemm_config.rhs_info;
 
     auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
     ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
@@ -567,8 +612,6 @@
     GEMMRHSMatrixInfo rhs_info;
 
     // Pick up the GEMM configuration
-    // Note there is no need to validate the configuration from mlgo heuristics as it is already validated in configure() and will fall back
-    // to default heuristics should it fail
     // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
     const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
     lhs_info               = gemm_config.lhs_info;
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
index 8f0d5f4..5790e07 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.cpp
@@ -29,6 +29,7 @@
 #include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
 #include "src/core/CL/ICLGEMMKernelConfiguration.h"
 #include "src/core/CL/gemm/CLGEMMHelpers.cpp"
+#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
 #include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
 #include "src/runtime/CL/mlgo/MLGOHeuristics.h"
@@ -103,6 +104,42 @@
                                                           config.export_cl_image);
     return GEMMConfigResult{ valid, lhs_info, rhs_info };
 }
+
+GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query)
+{
+    GEMMLHSMatrixInfo                           lhs_info;
+    GEMMRHSMatrixInfo                           rhs_info;
+    std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target);
+    ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
+    std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
+    return GEMMConfigResult{ true, lhs_info, rhs_info };
+}
+
+GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
+{
+    bool                     valid = false;
+    GEMMLHSMatrixInfo        lhs_info;
+    GEMMRHSMatrixInfo        rhs_info;
+    mlgo::GEMMConfigReshaped config{};
+    auto                     mlgo_heuristics = CLScheduler::get().gemm_heuristics();
+    if(mlgo_heuristics != nullptr)
+    {
+        std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
+    }
+    if(valid)
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
+    }
+    else
+    {
+        ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
+    }
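+    // Note: the mlgo reshaped config only provides transpose_rhs; the lhs transpose flag is therefore set to its negation (!config.transpose_rhs)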
+    std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs,
+                                                          config.transpose_rhs,
+                                                          config.export_cl_image);
+    return GEMMConfigResult{ valid, lhs_info, rhs_info };
+}
 } // namespace auto_heuristics
+
 } // namespace cl_gemm
 } // namespace arm_compute
\ No newline at end of file
diff --git a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
index e4fa1a6..7cb9cab 100644
--- a/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
+++ b/src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h
@@ -82,6 +82,19 @@
  * @return GEMMConfigResult
  */
 GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query);
+
+/** Select gemm config based on mlgo heuristics
+ * @param query Query
+ * @return GEMMConfigResult
+ */
+GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query);
+
+/** Select gemm config based on default heuristics
+ * @param query Query
+ * @return GEMMConfigResult
+ */
+GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query);
+
 } // namespace auto_heuristics
 } // namespace cl_gemm
 } // namespace arm_compute