Integrate MLGO into CLGEMMLowpMatrixMultiplyCore for native kernel

Resolves COMPMID-3846
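
For the native kernel this follows the mlgo-first pattern already used
for the reshaped-only-rhs kernel: query the MLGO heuristics, validate
the returned lhs/rhs info against
CLGEMMLowpMatrixMultiplyNativeKernel::validate(), and fall back to the
default heuristics if the query or the validation fails. A condensed
sketch of the selection logic added in this patch (logging elided):

    std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_native(
        auto_heuristics::CommonQuery query, const ITensorInfo *a, const ITensorInfo *b,
        const GEMMReshapeInfo &reshape_info)
    {
        // Prefer the mlgo-provided config, but only if it passes kernel validation
        auto config = auto_heuristics::select_mlgo_gemm_config_native(query);
        if(config && validate_lhs_rhs_info_native(config.lhs_info, config.rhs_info, a, b, reshape_info))
        {
            return { config.lhs_info, config.rhs_info };
        }
        // Otherwise fall back to the default heuristics
        config = auto_heuristics::select_default_gemm_config_native(query);
        return { config.lhs_info, config.rhs_info };
    }

The validate() path queries only the default heuristics, since an mlgo
config that fails validation is never used by configure().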

Signed-off-by: SiCong Li <sicong.li@arm.com>
Change-Id: Iad66f6dd7fa5b13ebace9f95fbc2fc4d677cf6a9
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5032
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 6c4d9ef..2c9bb3c 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -34,8 +34,6 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/CL/CLScheduler.h"
-#include "src/core/CL/gemm/native/CLGEMMNativeKernelConfiguration.h"
-#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "src/core/CL/kernels/CLDepthConvertLayerKernel.h"
 #include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h"
 #include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h"
@@ -44,7 +42,6 @@
 #include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h"
 #include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
 #include "src/core/helpers/AutoConfiguration.h"
-#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
 #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
 #include "utils/TypePrinter.h"
 
@@ -55,6 +52,43 @@
 
 namespace
 {
+// Validate lhs_info and rhs_info for native kernel
+inline bool validate_lhs_rhs_info_native(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
+{
+    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for the native kernel
+    TensorInfo mm_result_s32_info{};
+    // Output tensor auto initialization if not yet initialized
+    auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*a, *b, false, reshape_info)).set_data_type(DataType::S32));
+    // Validate mm kernel
+    // NOTE: Ignore all other parameters (e.g. output stage) and only validate lhs and rhs info
+    // NOTE: This assumes:
+    //  1. lhs and rhs info's validity does not depend on these other parameters and vice versa (in CLGEMMLowpMatrixMultiplyNativeKernel.cpp validate_arguments).
+    //  2. lhs and rhs info do not cause window and padding issues through side effects (in CLGEMMLowpMatrixMultiplyNativeKernel.cpp validate_and_configure_window).
+    if(!bool(CLGEMMLowpMatrixMultiplyNativeKernel::validate(a, b, &mm_result_s32_info, lhs_info, rhs_info, reshape_info)))
+    {
+        return false;
+    }
+    return true;
+}
+
+// Automatically select between mlgo (prioritized) and default heuristics for native kernel configs
+std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_native(auto_heuristics::CommonQuery query, const ITensorInfo *a, const ITensorInfo *b, const GEMMReshapeInfo &reshape_info)
+{
+    auto config = auto_heuristics::select_mlgo_gemm_config_native(query);
+    if(config)
+    {
+        if(validate_lhs_rhs_info_native(config.lhs_info, config.rhs_info, a, b, reshape_info))
+        {
+            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use native config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+            return { config.lhs_info, config.rhs_info };
+        }
+    }
+    config = auto_heuristics::select_default_gemm_config_native(query);
+    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use native config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
+    return { config.lhs_info, config.rhs_info };
+}
+
+// Validate lhs_info and rhs_info for reshaped only rhs kernel
 inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *output,
                                                     unsigned int m, unsigned int n, unsigned int k, bool reinterpret_input_as_3d, int depth_output_gemm3d)
 {
@@ -89,6 +123,7 @@
     return true;
 }
 
+// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, bool reinterpret_input_as_3d, int depth_output_gemm3d,
                                                                                           const ITensorInfo *a,
                                                                                           const ITensorInfo *b, const ITensorInfo *output)
@@ -102,7 +137,7 @@
             return { config.lhs_info, config.rhs_info };
         }
     }
-    config = select_default_gemm_config_reshaped_only_rhs(query);
+    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
     ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
     return { config.lhs_info, config.rhs_info };
 }
@@ -195,6 +230,8 @@
     const unsigned int batch_size              = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
     const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
 
+    const auto reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d);
+
     // Check if we need to reshape the matrix A and matrix B
     _is_gemm_reshaped = is_gemm_reshaped(auto_select_gemm_kernel(auto_heuristics::CommonQuery{ gpu_target, a->info()->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run));
 
@@ -298,10 +335,12 @@
             else
             {
                 // Pick up the GEMM configuration
-                std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+                // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+                std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size },
+                                                                              _matrix_a->info(), _convert_to_qasymm8 ? _qasymm8_weights.info() : matrix_b->info(), reshape_info);
 
                 // Configure matrix multiply kernel
-                _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+                _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, &_mm_result_s32, lhs_info, rhs_info, reshape_info);
 
                 _offset_contribution_output_stage_kernel->configure(compile_context, &_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output,
                                                                     a->info()->dimension(0),
@@ -331,10 +370,12 @@
         else
         {
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            std::tie(lhs_info, rhs_info) = auto_select_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size },
+                                                                          a->info(), _convert_to_qasymm8 ? _qasymm8_weights.info() : b->info(), reshape_info);
 
             // Configure matrix multiply kernel
-            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, output, lhs_info, rhs_info, GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d));
+            _mm_native_kernel->configure(compile_context, _matrix_a, matrix_b, output, lhs_info, rhs_info, reshape_info);
         }
 
         // Configure offset contribution kernel
@@ -490,7 +531,11 @@
                 auto_init_if_empty(mm_result_s32_info, a->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, false, reshape_info)).set_data_type(DataType::S32));
 
                 // Pick up the GEMM configuration
-                std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+                // NOTE: No need to validate mlgo configurations here: configure() automatically falls back to the default heuristics if an mlgo config fails validation
+                // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+                const auto res = auto_heuristics::select_default_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size });
+                lhs_info       = res.lhs_info;
+                rhs_info       = res.rhs_info;
 
                 // Validate matrix multiply
                 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, &mm_result_s32_info, lhs_info, rhs_info, reshape_info));
@@ -518,7 +563,10 @@
         else
         {
             // Pick up the GEMM configuration
-            std::tie(lhs_info, rhs_info) = CLGEMMNativeKernelConfigurationFactory::create(gpu_target)->configure(m, n, k, batch_size, DataType::QASYMM8);
+            // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration
+            const auto res = auto_heuristics::select_default_gemm_config_native(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size });
+            lhs_info       = res.lhs_info;
+            rhs_info       = res.rhs_info;
 
             // Validate matrix multiply
             ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixMultiplyNativeKernel::validate(matrix_a_info, matrix_b_info, output, lhs_info, rhs_info, reshape_info));