blob: 5790e077d4f897ff7737eaeab7b2514f161d14c7 [file] [log] [blame]
SiCong Libd8b1e22021-02-04 13:07:09 +00001/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
25
26#include "arm_compute/core/Log.h"
27#include "arm_compute/core/Validate.h"
28#include "arm_compute/runtime/CL/CLScheduler.h"
29#include "arm_compute/runtime/CL/ICLGEMMKernelSelection.h"
30#include "src/core/CL/ICLGEMMKernelConfiguration.h"
31#include "src/core/CL/gemm/CLGEMMHelpers.cpp"
SiCong Li8c23ba12021-02-08 14:19:23 +000032#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
SiCong Libd8b1e22021-02-04 13:07:09 +000033#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
34#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
35#include "src/runtime/CL/mlgo/MLGOHeuristics.h"
36#include "src/runtime/CL/mlgo/Utils.h"
37#include "utils/TypePrinter.h"
38
39namespace arm_compute
40{
41namespace cl_gemm
42{
43namespace auto_heuristics
44{
45CLGEMMKernelType auto_select_gemm_kernel(const CommonQuery &query, bool reshape_b_only_on_first_run)
46{
47 // Select between mlgo and default heuristics
48 auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
49 if(mlgo_heuristics != nullptr)
50 {
51 auto res = mlgo_heuristics->get()->query_gemm_type(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
52 if(res.first)
53 {
54 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(res.second).c_str());
55 return res.second;
56 }
57 }
58 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(query.gpu_target);
59 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
60
61 CLGEMMKernelSelectionParams params;
62 params.m = query.m;
63 params.n = query.n;
64 params.k = query.k;
65 params.b = query.b;
66 params.is_rhs_constant = reshape_b_only_on_first_run;
67 params.data_type = query.data_type;
68
69 const auto kernel_type = gemm_kernel->select_kernel(params);
70 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(kernel_type).c_str());
71 return kernel_type;
72}
73
74GEMMConfigResult select_default_gemm_config_reshaped_only_rhs(const CommonQuery &query)
75{
76 GEMMLHSMatrixInfo lhs_info;
77 GEMMRHSMatrixInfo rhs_info;
78 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(query.gpu_target);
79 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
80 std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
81 return GEMMConfigResult{ true, lhs_info, rhs_info };
82}
83
84GEMMConfigResult select_mlgo_gemm_config_reshaped_only_rhs(const CommonQuery &query)
85{
86 bool valid = false;
87 GEMMLHSMatrixInfo lhs_info;
88 GEMMRHSMatrixInfo rhs_info;
89 mlgo::GEMMConfigReshapedOnlyRHS config{};
90 auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
91 if(mlgo_heuristics != nullptr)
92 {
93 std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped_only_rhs(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
94 }
95 if(valid)
96 {
97 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
98 }
99 else
100 {
101 ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
102 }
103 std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, 1, config.h0, false, config.interleave_rhs, !config.transpose_rhs, config.transpose_rhs,
104 config.export_cl_image);
105 return GEMMConfigResult{ valid, lhs_info, rhs_info };
106}
SiCong Li8c23ba12021-02-08 14:19:23 +0000107
108GEMMConfigResult select_default_gemm_config_reshaped(const CommonQuery &query)
109{
110 GEMMLHSMatrixInfo lhs_info;
111 GEMMRHSMatrixInfo rhs_info;
112 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(query.gpu_target);
113 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
114 std::tie(lhs_info, rhs_info) = gemm_config->configure(query.m, query.n, query.k, query.b, query.data_type);
115 return GEMMConfigResult{ true, lhs_info, rhs_info };
116}
117
118GEMMConfigResult select_mlgo_gemm_config_reshaped(const CommonQuery &query)
119{
120 bool valid = false;
121 GEMMLHSMatrixInfo lhs_info;
122 GEMMRHSMatrixInfo rhs_info;
123 mlgo::GEMMConfigReshaped config{};
124 auto mlgo_heuristics = CLScheduler::get().gemm_heuristics();
125 if(mlgo_heuristics != nullptr)
126 {
127 std::tie(valid, config) = mlgo_heuristics->get()->query_gemm_config_reshaped(mlgo::Query{ string_from_target(query.gpu_target), query.data_type, query.m, query.n, query.k, query.b });
128 }
129 if(valid)
130 {
131 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("MLGOHeuristics query returns gemm config: %s.", to_string(config).c_str());
132 }
133 else
134 {
135 ARM_COMPUTE_LOG_INFO_MSG_CORE("MLGOHeuristics query failed");
136 }
137 std::tie(lhs_info, rhs_info) = configure_lhs_rhs_info(query.m, query.n, config.m0, config.n0, config.k0, config.v0, config.h0, config.interleave_lhs, config.interleave_rhs, !config.transpose_rhs,
138 config.transpose_rhs,
139 config.export_cl_image);
140 return GEMMConfigResult{ valid, lhs_info, rhs_info };
141}
SiCong Libd8b1e22021-02-04 13:07:09 +0000142} // namespace auto_heuristics
SiCong Li8c23ba12021-02-08 14:19:23 +0000143
SiCong Libd8b1e22021-02-04 13:07:09 +0000144} // namespace cl_gemm
145} // namespace arm_compute