blob: 815c254c69d563509976c87d8de149cb61c3bdf0 [file] [log] [blame]
Georgios Pinitas856f66e2021-04-22 21:13:21 +01001/*
SiCong Li13bab712023-01-13 15:29:39 +00002 * Copyright (c) 2017-2023 Arm Limited.
Georgios Pinitas856f66e2021-04-22 21:13:21 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/gpu/cl/operators/ClGemm.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010025
26#include "arm_compute/core/CL/CLKernelLibrary.h"
27#include "arm_compute/core/CL/ICLTensor.h"
28#include "arm_compute/core/Error.h"
29#include "arm_compute/core/GPUTarget.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/KernelDescriptors.h"
32#include "arm_compute/core/Log.h"
33#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/Types.h"
35#include "arm_compute/core/Utils.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010037#include "arm_compute/core/Validate.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010038#include "arm_compute/runtime/CL/CLScheduler.h"
39#include "arm_compute/runtime/ITensorAllocator.h"
Georgios Pinitas2b147ee2021-07-08 18:14:45 +010040
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010041#include "src/common/utils/Log.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010042#include "src/core/helpers/AutoConfiguration.h"
43#include "src/core/helpers/MemoryHelpers.h"
44#include "src/core/utils/helpers/float_ops.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010045#include "src/gpu/cl/IClKernel.h"
46#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010047#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
48#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010049#include "support/Cast.h"
50#include "utils/TypePrinter.h"
51
52namespace arm_compute
53{
54namespace opencl
55{
56using namespace arm_compute::misc::shape_calculator;
57using namespace arm_compute::cl_gemm;
58using namespace arm_compute::experimental;
59using namespace arm_compute::utils::cast;
60using namespace arm_compute::opencl::kernels;
61
62namespace
63{
64inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
65{
SiCongLi579ca842021-10-18 09:38:33 +010066 return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +010067}
68//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010069inline CLGEMMKernelType
70auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
Georgios Pinitas856f66e2021-04-22 21:13:21 +010071{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010072 if (!constant_weights)
Giorgio Arena4403ed32021-05-17 13:03:50 +010073 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +010074 return CLGEMMKernelType::NATIVE;
Giorgio Arena4403ed32021-05-17 13:03:50 +010075 }
76
Georgios Pinitas856f66e2021-04-22 21:13:21 +010077 auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010078 if (bool(gemm_kernel))
Georgios Pinitas856f66e2021-04-22 21:13:21 +010079 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010080 if (validate_gemm_kernel(gemm_kernel.gemm_type))
Georgios Pinitas856f66e2021-04-22 21:13:21 +010081 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010082 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.",
83 to_string(gemm_kernel.gemm_type).c_str());
Georgios Pinitas856f66e2021-04-22 21:13:21 +010084 return gemm_kernel.gemm_type;
85 }
86 }
87 gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010088 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.",
89 to_string(gemm_kernel.gemm_type).c_str());
Georgios Pinitas856f66e2021-04-22 21:13:21 +010090 return gemm_kernel.gemm_type;
91}
92// Validate lhs_info and rhs_info for reshaped only rhs kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +010093inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info,
94 const GEMMRHSMatrixInfo &rhs_info,
95 const ITensorInfo *a,
96 const ITensorInfo *b,
97 const ITensorInfo *c,
98 const ITensorInfo *output,
99 GEMMKernelInfo gemm_kernel_info)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100100{
101 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
102 TensorInfo tmp_b_info{};
103 // Validate reshape RHS kernel
104 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100105 if (!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100106 {
107 return false;
108 }
109 // Validate mm kernel
110 gemm_kernel_info.lhs_info = lhs_info;
111 gemm_kernel_info.rhs_info = rhs_info;
112 gemm_kernel_info.has_pad_y = false;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100113 if (!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info,
114 rhs_info, gemm_kernel_info)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100115 {
116 return false;
117 }
118 gemm_kernel_info.has_pad_y = true;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100119 if (!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info,
120 rhs_info, gemm_kernel_info)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100121 {
122 return false;
123 }
124 return true;
125}
126
127//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100128inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
129auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query,
130 GEMMKernelInfo kernel_info,
131 const ITensorInfo *a,
132 const ITensorInfo *b,
133 const ITensorInfo *c,
134 const ITensorInfo *output)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100135{
136 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100137 if (config)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100138 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100139 if (validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100140 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100141 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(
142 "Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
143 to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
144 return {config.lhs_info, config.rhs_info};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100145 }
146 }
147 config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100148 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(
149 "Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ",
150 to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
151 return {config.lhs_info, config.rhs_info};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100152}
153
154// Validate lhs_info and rhs_info for reshaped kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100155inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info,
156 const GEMMRHSMatrixInfo &rhs_info,
157 const ITensorInfo *a,
158 const ITensorInfo *b,
159 const ITensorInfo *c,
160 const ITensorInfo *output,
161 GEMMKernelInfo gemm_kernel_info,
162 bool reinterpret_input_as_3d)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100163{
164 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
165 TensorInfo tmp_a_info{};
166 TensorInfo tmp_b_info{};
167
168 // Validate reshape LHS kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100169 auto_init_if_empty(tmp_a_info,
170 a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
171 if (!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100172 {
173 return false;
174 }
175
176 // Validate reshape RHS kernel
177 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100178 if (!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100179 {
180 return false;
181 }
182 // Validate mm kernel
183 gemm_kernel_info.lhs_info = lhs_info;
184 gemm_kernel_info.rhs_info = rhs_info;
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100185 if (!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info,
186 rhs_info, gemm_kernel_info)))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100187 {
188 return false;
189 }
190 return true;
191}
192
193//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100194inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo>
195auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query,
196 GEMMKernelInfo kernel_info,
197 const ITensorInfo *a,
198 const ITensorInfo *b,
199 const ITensorInfo *c,
200 const ITensorInfo *output,
201 bool reinterpret_input_as_3d)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100202{
203 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100204 if (config)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100205 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100206 if (validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info,
207 reinterpret_input_as_3d))
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100208 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100209 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(
210 "Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ",
211 to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
212 return {config.lhs_info, config.rhs_info};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100213 }
214 }
215 config = auto_heuristics::select_default_gemm_config_reshaped(query);
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100216 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE(
217 "Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(),
218 to_string(config.rhs_info).c_str());
219 return {config.lhs_info, config.rhs_info};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100220}
221} // namespace
222
// Default constructor: instantiate every candidate kernel up front; configure()
// later selects and configures only the variant chosen by the heuristics.
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE), // placeholder until configure() picks via heuristics
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)
{
}
238
/** Configure the NATIVE GEMM variant: no LHS/RHS reshape, a single mm kernel.
 *
 * @param[in] compile_context OpenCL compile context used to build the kernel.
 * @param[in] a               LHS tensor info.
 * @param[in] b               RHS tensor info.
 * @param[in] c               Bias tensor info.
 * @param[in] output          Destination tensor info.
 * @param[in] alpha           Scalar weight forwarded to the mm kernel.
 * @param[in] beta            Scalar weight forwarded to the mm kernel.
 * @param[in] gemm_info       GEMM meta-data (3D reinterpretation, bias broadcast, activation, ...).
 */
void ClGemm::configure_native(const CLCompileContext &compile_context,
                              ITensorInfo            *a,
                              ITensorInfo            *b,
                              ITensorInfo            *c,
                              ITensorInfo            *output,
                              float                   alpha,
                              float                   beta,
                              const GEMMInfo         &gemm_info)
{
    // Derive the GEMM problem sizes. With 3D-reinterpreted input, M spans dims 1
    // and 2 of the LHS and the batch moves up to dim 3.
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target          = CLScheduler::get().target();
    bool               broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    // NOTE(review): the block configuration is queried from the reshaped_only_rhs
    // mlgo heuristics even though this is the native kernel; validate_native()
    // does the same, so configure/validate stay consistent — confirm intended.
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info,
                                 kernel_info);
}
277
/** Configure the fully-RESHAPED GEMM variant: both LHS and RHS are reshaped into
 *  auxiliary tensors (_tmp_a/_tmp_b) before the matrix multiply.
 *
 * @param[in] compile_context OpenCL compile context used to build the kernels.
 * @param[in] a               LHS tensor info.
 * @param[in] b               RHS tensor info.
 * @param[in] c               Bias tensor info.
 * @param[in] output          Destination tensor info.
 * @param[in] alpha           Scalar weight forwarded to the mm kernel.
 * @param[in] beta            Scalar weight forwarded to the mm kernel.
 * @param[in] gemm_info       GEMM meta-data (3D reinterpretation, bias broadcast, activation, ...).
 */
void ClGemm::configure_reshaped(const CLCompileContext &compile_context,
                                ITensorInfo            *a,
                                ITensorInfo            *b,
                                ITensorInfo            *c,
                                ITensorInfo            *output,
                                float                   alpha,
                                float                   beta,
                                const GEMMInfo         &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target          = CLScheduler::get().target();
    bool               broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                   = m;
    kernel_info.n                   = n;
    kernel_info.k                   = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    // False for the mm kernel: the 3D view is consumed by the LHS reshape kernel
    // below (which receives the flag), so the mm kernel sees a flat LHS.
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) =
        auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size},
                                         kernel_info, a, b, c, output, gemm_info.reinterpret_input_as_3d());

    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info,
                                   kernel_info);

    // Request memory for LHS and RHS reshape matrix. The reshaped RHS persists
    // across runs when the reshape only happens on the first run.
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(
        offset_int_vec(RhsReshape),
        _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
331
/** Configure the RESHAPED_ONLY_RHS GEMM variant: only the RHS is reshaped into
 *  the auxiliary tensor _tmp_b; the LHS is consumed directly.
 *
 * @param[in] compile_context OpenCL compile context used to build the kernels.
 * @param[in] a               LHS tensor info.
 * @param[in] b               RHS tensor info.
 * @param[in] c               Bias tensor info.
 * @param[in] output          Destination tensor info.
 * @param[in] alpha           Scalar weight forwarded to the mm kernel.
 * @param[in] beta            Scalar weight forwarded to the mm kernel.
 * @param[in] gemm_info       GEMM meta-data (3D reinterpretation, bias broadcast, activation, ...).
 */
void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context,
                                         ITensorInfo            *a,
                                         ITensorInfo            *b,
                                         ITensorInfo            *c,
                                         ITensorInfo            *output,
                                         float                   alpha,
                                         float                   beta,
                                         const GEMMInfo         &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target          = CLScheduler::get().target();
    bool               broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size}, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
    // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
    // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info,
                                            kernel_info);

    // Request memory for RHS reshape matrix; it persists across runs when the
    // reshape only happens on the first run.
    _aux_mem[RhsReshape] = MemoryInfo(
        offset_int_vec(RhsReshape),
        _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
387
/** Configure the RESHAPED_ONLY_RHS_MMUL GEMM variant: RHS-only reshape feeding
 *  the MMUL-extension matrix-multiply kernel.
 *
 * @param[in] compile_context OpenCL compile context used to build the kernels.
 * @param[in] a               LHS tensor info.
 * @param[in] b               RHS tensor info.
 * @param[in] c               Bias tensor info.
 * @param[in] output          Destination tensor info.
 * @param[in] alpha           Scalar weight forwarded to the mm kernel.
 * @param[in] beta            Scalar weight forwarded to the mm kernel.
 * @param[in] gemm_info       GEMM meta-data (3D reinterpretation, bias broadcast, activation, ...).
 */
void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context,
                                              ITensorInfo            *a,
                                              ITensorInfo            *b,
                                              ITensorInfo            *c,
                                              ITensorInfo            *output,
                                              float                   alpha,
                                              float                   beta,
                                              const GEMMInfo         &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target          = CLScheduler::get().target();
    bool               broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration. NOTE: only the default heuristics are
    // queried here (no mlgo), unlike the other configure_* paths.
    auto gemm_config = select_default_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    // Reshape Rhs matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info,
                                                 rhs_info, kernel_info);

    // Request memory for RHS reshape matrix; it persists across runs when the
    // reshape only happens on the first run.
    _aux_mem[RhsReshape] = MemoryInfo(
        offset_int_vec(RhsReshape),
        _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
443
/** Static validation for the NATIVE GEMM variant.
 *
 * Mirrors configure_native(): builds the same GEMMKernelInfo and block config
 * and asks the native mm kernel whether the combination is supported.
 *
 * @return A Status; an error Status if any kernel rejects the configuration.
 */
Status ClGemm::validate_native(const ITensorInfo *a,
                               const ITensorInfo *b,
                               const ITensorInfo *c,
                               const ITensorInfo *output,
                               float              alpha,
                               float              beta,
                               const GEMMInfo    &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    // Same reshaped_only_rhs heuristics query as configure_native(), keeping the
    // two paths consistent.
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(
        a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}
484
/** Static validation for the fully-RESHAPED GEMM variant.
 *
 * Validates the LHS reshape, the RHS reshape, and the reshaped mm kernel with
 * the default-heuristics block configuration.
 *
 * @return A Status; an error Status if any kernel rejects the configuration.
 */
Status ClGemm::validate_reshaped(const ITensorInfo *a,
                                 const ITensorInfo *b,
                                 const ITensorInfo *c,
                                 const ITensorInfo *output,
                                 float              alpha,
                                 float              beta,
                                 const GEMMInfo    &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                   = m;
    kernel_info.n                   = n;
    kernel_info.k                   = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    // False for the mm kernel: the 3D view is handled by the LHS reshape kernel
    // (validated below with the flag), matching configure_reshaped().
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config =
        select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(
                                       compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(
        ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha,
                                                                             beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
543
/** Static validation for the RESHAPED_ONLY_RHS GEMM variant.
 *
 * Validates the RHS reshape and the mm kernel for both y-padding variants,
 * since either may be dispatched at run time (see configure_reshaped_only_rhs).
 *
 * @return A Status; an error Status if any kernel rejects the configuration.
 */
Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a,
                                          const ITensorInfo *b,
                                          const ITensorInfo *c,
                                          const ITensorInfo *output,
                                          float              alpha,
                                          float              beta,
                                          const GEMMInfo    &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                   = b->dimension(0);
    const unsigned int k                   = a->dimension(0);
    const unsigned int batch_size          = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias      = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply (no y padding)
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(
        a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    // Validate matrix multiply (with y padding)
    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(
        a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
601
// Validates the RESHAPED_ONLY_RHS_MMUL GEMM variant: first checks that the RHS
// reshape kernel accepts b with the chosen RHS block layout, then that the MMUL
// matrix-multiply kernel accepts the reshaped operands. Fails fast via
// ARM_COMPUTE_RETURN_ON_ERROR on the first kernel that rejects the configuration.
//
// @param a         LHS matrix info; K is dimension 0, M is dimension 1 (collapsed with dim 2 when input is reinterpreted as 3D)
// @param b         RHS matrix info; N is dimension 0
// @param c         Optional bias info (may be nullptr)
// @param output    Destination tensor info
// @param alpha     Scalar multiplier for A*B
// @param beta      Scalar multiplier applied to the bias
// @param gemm_info GEMM attributes (3D reinterpretation, 3D output depth, bias broadcast, fused activation)
// @return          Empty Status on success, an error Status otherwise
Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a,
                                               const ITensorInfo *b,
                                               const ITensorInfo *c,
                                               const ITensorInfo *output,
                                               float              alpha,
                                               float              beta,
                                               const GEMMInfo    &gemm_info)
{
    // NOTE(review): alpha and output are both forwarded to the kernel validate call
    // below, so these UNUSED markers look precautionary (presumably for builds where
    // the validation macros compile out) — confirm before removing.
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);
    TensorInfo tmp_b_info{}; // Shape/metadata of the reshaped RHS, filled in below

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    // When the input is reinterpreted as 3D, M spans dims 1 and 2 and the batch moves to dim 3.
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias      = gemm_info.broadcast_bias();

    // Mirror the GEMM attributes into the kernel descriptor used for validation
    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    // NOTE(review): this reuses the generic reshaped-only-RHS heuristic rather than an
    // MMUL-specific one, then overrides h0 below — presumably intentional; confirm.
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(
        auto_heuristics::CommonQuery{gpu_target, data_type, m, n, k, batch_size});
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    // Derive the reshaped-RHS shape and validate the reshape kernel against it
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(
        a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
656
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100657void ClGemm::configure(const CLCompileContext &compile_context,
658 ITensorInfo *a,
659 ITensorInfo *b,
660 ITensorInfo *c,
661 ITensorInfo *output,
662 float alpha,
663 float beta,
664 const GEMMInfo &gemm_info)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100665{
666 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
667
668 // Perform validation step
669 ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
ramelg012e53f172021-09-22 10:48:25 +0100670 ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100671
672 // Check if we need to reshape the matrix B only on the first run
673 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
Georgios Pinitasf5d51f32021-08-17 16:09:10 +0100674 _is_prepared = gemm_info.retain_internal_weights();
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100675
676 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100677 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
678 const unsigned int n = b->dimension(0);
679 const unsigned int k = a->dimension(0);
680 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100681
682 // Select GEMMType
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100683 _gemm_kernel_type = auto_select_gemm_kernel(
684 auto_heuristics::CommonQuery{CLScheduler::get().target(), a->data_type(), m, n, k, batch_size},
685 _reshape_b_only_on_first_run, b->are_values_constant());
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100686
687 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
688
689 ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
690
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100691 switch (_gemm_kernel_type)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100692 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100693 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100694 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100695 configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100696 break;
697 }
698 case CLGEMMKernelType::RESHAPED:
699 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100700 configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100701 break;
702 }
703 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
704 {
705 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
706 break;
707 }
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000708 case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
709 {
710 configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
711 break;
712 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100713 default:
714 {
715 ARM_COMPUTE_ERROR("GEMMType not supported");
716 }
717 }
718}
719
// Static validation entry point: checks the data type, selects the kernel variant
// via the same heuristic used by configure(), and dispatches to the matching
// per-variant validate_* helper. Returns an empty Status on success.
//
// @param a         LHS matrix info; K is dimension 0, M is dimension 1 (collapsed with dim 2 when input is reinterpreted as 3D)
// @param b         RHS matrix info; N is dimension 0
// @param c         Optional bias info (may be nullptr)
// @param output    Destination tensor info
// @param alpha     Scalar multiplier for A*B
// @param beta      Scalar multiplier applied to the bias
// @param gemm_info GEMM attributes (reshape policy, 3D reinterpretation, ...)
// @return          Empty Status on success, an error Status otherwise
Status ClGemm::validate(const ITensorInfo *a,
                        const ITensorInfo *b,
                        const ITensorInfo *c,
                        const ITensorInfo *output,
                        float              alpha,
                        float              beta,
                        const GEMMInfo    &gemm_info)
{
    // Get the GPU target
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    // With a 3D-reinterpreted input, M spans dims 1 and 2 and the batch moves to dim 3.
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Check data type early because the auto_select_gemm_kernel has assertions on supported data types
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);

    // Select GEMMType using the same heuristic query configure() issues, so that
    // validation exercises the variant that will actually be configured.
    CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(
        auto_heuristics::CommonQuery{
            CLScheduler::get().target(),
            a->data_type(),
            m,
            n,
            k,
            batch_size,
        },
        gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());

    // Bias is only fused when beta is non-zero and a bias tensor was provided
    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    // Dispatch to the variant-specific validation
    switch (gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(
                validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        default:
        {
            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
        }
    }

    return Status{};
}
785
786void ClGemm::run(ITensorPack &tensors)
787{
SiCongLiafa19722021-10-24 19:12:33 +0100788 const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
789 const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
790 ITensor *dst = tensors.get_tensor(ACL_DST);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100791
792 ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);
793
794 CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
795 CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);
796
797 // Prepare the consts if needed
798 prepare(tensors);
799
800 // Run matrix multiply kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100801 switch (_gemm_kernel_type)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100802 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100803 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100804 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100805 CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100806 break;
807 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100808 case CLGEMMKernelType::RESHAPED:
809 {
810 // Run interleave kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100811 ITensorPack reshape_lhs_pack{{ACL_SRC, lhs}, {ACL_DST, lhs_reshaped.get()}};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100812 CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);
813
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100814 if (!_reshape_b_only_on_first_run)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100815 {
816 // Run transpose kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100817 ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100818 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
819 }
SiCongLi579ca842021-10-18 09:38:33 +0100820 // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
821 ITensorPack gemm_reshaped_pack(tensors);
822 gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
823 gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
Manuel Bottinid87aded2021-07-16 10:23:31 +0100824
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100825 if (_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100826 {
827 CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
828 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100829 break;
830 }
831 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
832 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100833 if (!_reshape_b_only_on_first_run)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100834 {
835 // Run transpose kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100836 ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}};
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100837 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
838 }
839 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
840 // Check if the lhs or dst tensors have padding
841 const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
842 const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
843 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
844
SiCongLiafa19722021-10-24 19:12:33 +0100845 // Copy original tensor pack and overwrite rhs with reshaped counterpart
846 ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
847 gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
848
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100849 if (has_pad_y)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100850 {
ramelg019cca5922021-11-11 10:05:00 +0000851 ARM_COMPUTE_ERROR_ON(has_pad_y);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100852 }
853 else
854 {
855 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
856 }
857 break;
858 }
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000859 case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
860 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100861 if (!_reshape_b_only_on_first_run)
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000862 {
863 // Run transpose kernel
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100864 ITensorPack reshape_rhs_pack{{ACL_SRC, rhs}, {ACL_DST, rhs_reshaped.get()}};
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000865 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
866 }
867 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
868 // Check if the lhs or dst tensors have padding
869 const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
870 const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
871 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
872
873 // Copy original tensor pack and overwrite rhs with reshaped counterpart
874 ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
875 gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
876
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100877 if (has_pad_y)
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000878 {
879 ARM_COMPUTE_ERROR_ON(has_pad_y);
880 }
881 else
882 {
883 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
884 }
885 break;
886 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100887 default:
888 {
889 ARM_COMPUTE_ERROR("GEMMType not supported");
890 }
891 }
892}
893
894void ClGemm::prepare(ITensorPack &constants)
895{
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100896 if (!_is_prepared)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100897 {
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100898 const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
899 ICLTensor *rhs_aux =
900 utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
Georgios Pinitas2b147ee2021-07-08 18:14:45 +0100901
Manuel Bottinid87aded2021-07-16 10:23:31 +0100902 // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100903 if ((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) &&
904 (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
Manuel Bottinid87aded2021-07-16 10:23:31 +0100905 {
906 ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100907
Manuel Bottinid87aded2021-07-16 10:23:31 +0100908 CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
909 ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
910
Felix Thomasmathibalanafd38f02023-09-27 17:46:17 +0100911 ITensorPack reshape_rhs_pack{{ACL_SRC, src1}, {ACL_DST, rhs_reshaped.get()}};
Manuel Bottinid87aded2021-07-16 10:23:31 +0100912 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
913 }
914 _is_prepared = true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100915 }
916}
917
// Returns the operator's auxiliary memory requirements (the workspace slots,
// e.g. the LHS/RHS reshape buffers, accumulated into _aux_mem during configure).
experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
922} // namespace opencl
923} // namespace arm_compute