blob: 8db6dabe580db97c2d82c50e9a5f322af1b5aa71 [file] [log] [blame]
Georgios Pinitas856f66e2021-04-22 21:13:21 +01001/*
SiCong Li13bab712023-01-13 15:29:39 +00002 * Copyright (c) 2017-2023 Arm Limited.
Georgios Pinitas856f66e2021-04-22 21:13:21 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/gpu/cl/operators/ClGemm.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010025
26#include "arm_compute/core/CL/CLKernelLibrary.h"
27#include "arm_compute/core/CL/ICLTensor.h"
28#include "arm_compute/core/Error.h"
29#include "arm_compute/core/GPUTarget.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/KernelDescriptors.h"
32#include "arm_compute/core/Log.h"
33#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/Types.h"
35#include "arm_compute/core/Utils.h"
36#include "arm_compute/core/Validate.h"
37#include "arm_compute/core/utils/misc/ShapeCalculator.h"
38#include "arm_compute/runtime/CL/CLScheduler.h"
39#include "arm_compute/runtime/ITensorAllocator.h"
Georgios Pinitas2b147ee2021-07-08 18:14:45 +010040
SiCongLi579ca842021-10-18 09:38:33 +010041#include "arm_compute/core/experimental/IPostOp.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010042#include "src/core/helpers/AutoConfiguration.h"
43#include "src/core/helpers/MemoryHelpers.h"
44#include "src/core/utils/helpers/float_ops.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010045#include "src/gpu/cl/IClKernel.h"
46#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010047#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
48#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010049
ramelg012e53f172021-09-22 10:48:25 +010050#include "src/common/utils/Log.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010051#include "support/Cast.h"
52#include "utils/TypePrinter.h"
53
54namespace arm_compute
55{
56namespace opencl
57{
58using namespace arm_compute::misc::shape_calculator;
59using namespace arm_compute::cl_gemm;
60using namespace arm_compute::experimental;
61using namespace arm_compute::utils::cast;
62using namespace arm_compute::opencl::kernels;
63
64namespace
65{
66inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
67{
SiCongLi579ca842021-10-18 09:38:33 +010068 return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +010069}
70//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
Giorgio Arena4403ed32021-05-17 13:03:50 +010071inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
Georgios Pinitas856f66e2021-04-22 21:13:21 +010072{
Giorgio Arena4403ed32021-05-17 13:03:50 +010073 if(!constant_weights)
74 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +010075 return CLGEMMKernelType::NATIVE;
Giorgio Arena4403ed32021-05-17 13:03:50 +010076 }
77
Georgios Pinitas856f66e2021-04-22 21:13:21 +010078 auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
79 if(bool(gemm_kernel))
80 {
81 if(validate_gemm_kernel(gemm_kernel.gemm_type))
82 {
83 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
84 return gemm_kernel.gemm_type;
85 }
86 }
87 gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
88 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
89 return gemm_kernel.gemm_type;
90}
91// Validate lhs_info and rhs_info for reshaped only rhs kernel
92inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
93 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
94{
95 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
96 TensorInfo tmp_b_info{};
97 // Validate reshape RHS kernel
98 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
99 if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
100 {
101 return false;
102 }
103 // Validate mm kernel
104 gemm_kernel_info.lhs_info = lhs_info;
105 gemm_kernel_info.rhs_info = rhs_info;
106 gemm_kernel_info.has_pad_y = false;
107 if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
108 {
109 return false;
110 }
111 gemm_kernel_info.has_pad_y = true;
112 if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
113 {
114 return false;
115 }
116 return true;
117}
118
119//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
120inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
121 const ITensorInfo *b,
122 const ITensorInfo *c, const ITensorInfo *output)
123{
124 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
125 if(config)
126 {
127 if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
128 {
129 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
130 return { config.lhs_info, config.rhs_info };
131 }
132 }
133 config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
134 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
135 return { config.lhs_info, config.rhs_info };
136}
137
138// Validate lhs_info and rhs_info for reshaped kernel
139inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
140 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
141{
142 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
143 TensorInfo tmp_a_info{};
144 TensorInfo tmp_b_info{};
145
146 // Validate reshape LHS kernel
147 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
148 if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
149 {
150 return false;
151 }
152
153 // Validate reshape RHS kernel
154 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
155 if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
156 {
157 return false;
158 }
159 // Validate mm kernel
160 gemm_kernel_info.lhs_info = lhs_info;
161 gemm_kernel_info.rhs_info = rhs_info;
162 if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
163 {
164 return false;
165 }
166 return true;
167}
168
169//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
170inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
171 const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
172{
173 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
174 if(config)
175 {
176 if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
177 {
178 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
179 return { config.lhs_info, config.rhs_info };
180 }
181 }
182 config = auto_heuristics::select_default_gemm_config_reshaped(query);
183 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
184 return { config.lhs_info, config.rhs_info };
185}
186} // namespace
187
// Default constructor: instantiate every candidate kernel up front. Only the one
// matching _gemm_kernel_type (selected later in configure()) is configured and run.
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel>()),
      _tmp_a(),                                  // LHS reshape scratch tensor info
      _tmp_b(),                                  // RHS reshape scratch tensor info
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE), // safe default; overwritten in configure()
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)              // one MemoryInfo slot per auxiliary tensor
{
}
203
// Configures the native (non-reshaped) matrix-multiply kernel.
// GEMM dimensions are derived from the tensor shapes: k = a->dim(0), m = a->dim(1)
// (times a->dim(2) when reinterpret_input_as_3d), n = b->dim(0); the batch count
// comes from the next dimension up.
void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    // Descriptor consumed by the mm kernel (activation and post-ops are fused in).
    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    // NOTE(review): this queries the reshaped_only_rhs mlgo config even though the
    // native kernel is being configured — presumably only the lhs_info/rhs_info block
    // parameters are reused; confirm this is intentional.
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}
235
// Configures the fully-reshaped GEMM pipeline: reshape LHS, reshape RHS, then the
// reshaped matrix-multiply kernel operating on the two scratch tensors.
void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    // The mm kernel sees 2D input: the 3d flag is handled by the LHS reshape kernel
    // below (which is configured with gemm_info.reinterpret_input_as_3d()).
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    // Configuring the reshape kernels also initializes _tmp_a/_tmp_b shapes, whose
    // total_size() is used for the aux-memory requests below.
    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix. The reshaped RHS is kept
    // persistent when B is only reshaped on the first run (constant weights).
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
280
// Configures the reshaped-only-RHS GEMM pipeline: reshape RHS into _tmp_b, then run
// the mm kernel directly on the (unreshaped) LHS and the reshaped RHS.
void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
    // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
    // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
    // NOTE(review): despite the comment above, only the has_pad_y = false variant is
    // configured in this function — the has_pad_y = true path must be handled at
    // prepare/run time (not visible here); confirm this is still correct.

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix. Persistent when B is only reshaped on
    // the first run (constant weights), temporary otherwise.
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
327
// Configures the reshaped-only-RHS GEMM pipeline using the MMUL (matrix-multiply)
// OpenCL extension. Mirrors configure_reshaped_only_rhs() but selects its config
// from the default heuristics only and forces h0 = 4 as required by the extension.
void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                              const GEMMInfo &gemm_info)
{
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (default heuristics only — no mlgo query here)
    auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info         = gemm_config.lhs_info;
    rhs_info         = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    // Reshape Rhs matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix. Persistent when B is only reshaped on
    // the first run (constant weights), temporary otherwise.
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
374
// Static validation for the native (non-reshaped) kernel path. Mirrors
// configure_native(): same dimension derivation, same kernel descriptor.
Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // NOTE(review): both macros below are redundant (alpha and output are passed to
    // the kernel validate call further down) but harmless.
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // NOTE(review): like configure_native(), this queries the reshaped_only_rhs mlgo
    // config for the native kernel — confirm the reuse is intentional.
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}
408
// Static validation for the fully-reshaped kernel path: checks the LHS reshape,
// the RHS reshape and the reshaped mm kernel against auto-initialized scratch infos.
Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Stand-ins for the _tmp_a/_tmp_b scratch tensors used at configure time.
    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    // Matches configure_reshaped(): the mm kernel sees 2D input, the 3d flag is
    // handled by the LHS reshape kernel.
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
458
// Static validation for the reshaped-only-RHS kernel path. The mm kernel is checked
// under both padding modes (has_pad_y false/true) since the choice is made at run time.
Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Stand-in for the _tmp_b scratch tensor used at configure time.
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    // Also validate the padded variant, which may be dispatched at run time.
    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
508
// Static validation for the MMUL reshaped-only-RHS kernel path. Mirrors
// configure_reshaped_only_rhs_mmul(): default heuristics config with h0 forced to 4,
// and only the has_pad_y = false variant is validated.
Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);
    // Stand-in for the _tmp_b scratch tensor used at configure time.
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;
    // Force H0 to 4 in order to use the MMUL extension
    rhs_info.h0 = 4;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
556
// Entry point: validates the arguments, selects the best kernel variant for the
// problem size / GPU target, and dispatches to the matching configure_* helper.
void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    // If internal weights are retained, prepare() becomes a no-op from the start.
    _is_prepared                 = gemm_info.retain_internal_weights();

    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType (mlgo heuristics prioritized over default heuristics;
    // non-constant weights force the NATIVE type)
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    // The bias addition is only fused into the mm kernel when beta != 0 and a bias
    // tensor was provided; otherwise the kernels are configured without c.
    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
        {
            configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}
611
612Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
613{
614 // Get the GPU target
615 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
616 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
617 const unsigned int n = b->dimension(0);
618 const unsigned int k = a->dimension(0);
619 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
620
SiCong Li13bab712023-01-13 15:29:39 +0000621 // Check data type early because the auto_select_gemm_kernel has assertions on supported data types
622 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::F32, DataType::F16);
623
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100624 // Select GEMMType
625 CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
626 {
627 CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
628 },
Giorgio Arena63e0beb2021-09-24 14:04:27 +0100629 gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100630
631 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
632
633 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
634
635 switch(gemm_kernel_type)
636 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100637 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100638 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100639 ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100640 break;
641 }
642 case CLGEMMKernelType::RESHAPED:
643 {
644 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
645 break;
646 }
647 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
648 {
649 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
650 break;
651 }
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000652 case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
653 {
654 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info));
655 break;
656 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100657 default:
658 {
659 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
660 }
661 }
662
663 return Status{};
664}
665
666void ClGemm::run(ITensorPack &tensors)
667{
SiCongLiafa19722021-10-24 19:12:33 +0100668 const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
669 const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
670 ITensor *dst = tensors.get_tensor(ACL_DST);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100671
672 ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);
673
674 CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
675 CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);
676
677 // Prepare the consts if needed
678 prepare(tensors);
679
680 // Run matrix multiply kernel
681 switch(_gemm_kernel_type)
682 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100683 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100684 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100685 CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100686 break;
687 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100688 case CLGEMMKernelType::RESHAPED:
689 {
690 // Run interleave kernel
691 ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
692 CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);
693
694 if(!_reshape_b_only_on_first_run)
695 {
696 // Run transpose kernel
697 ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
698 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
699 }
SiCongLi579ca842021-10-18 09:38:33 +0100700 // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
701 ITensorPack gemm_reshaped_pack(tensors);
702 gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
703 gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
Manuel Bottinid87aded2021-07-16 10:23:31 +0100704
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100705 if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
706 {
707 CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
708 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100709 break;
710 }
711 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
712 {
713 if(!_reshape_b_only_on_first_run)
714 {
715 // Run transpose kernel
716 ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
717 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
718 }
719 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
720 // Check if the lhs or dst tensors have padding
721 const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
722 const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
723 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
724
SiCongLiafa19722021-10-24 19:12:33 +0100725 // Copy original tensor pack and overwrite rhs with reshaped counterpart
726 ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
727 gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
728
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100729 if(has_pad_y)
730 {
ramelg019cca5922021-11-11 10:05:00 +0000731 ARM_COMPUTE_ERROR_ON(has_pad_y);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100732 }
733 else
734 {
735 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
736 }
737 break;
738 }
Gunes Bayir4bfc70e2021-12-10 16:17:56 +0000739 case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL:
740 {
741 if(!_reshape_b_only_on_first_run)
742 {
743 // Run transpose kernel
744 ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
745 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
746 }
747 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
748 // Check if the lhs or dst tensors have padding
749 const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
750 const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
751 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
752
753 // Copy original tensor pack and overwrite rhs with reshaped counterpart
754 ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
755 gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
756
757 if(has_pad_y)
758 {
759 ARM_COMPUTE_ERROR_ON(has_pad_y);
760 }
761 else
762 {
763 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true);
764 }
765 break;
766 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100767 default:
768 {
769 ARM_COMPUTE_ERROR("GEMMType not supported");
770 }
771 }
772}
773
774void ClGemm::prepare(ITensorPack &constants)
775{
Manuel Bottinid87aded2021-07-16 10:23:31 +0100776 if(!_is_prepared)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100777 {
Manuel Bottinid87aded2021-07-16 10:23:31 +0100778 const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
779 ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
Georgios Pinitas2b147ee2021-07-08 18:14:45 +0100780
Manuel Bottinid87aded2021-07-16 10:23:31 +0100781 // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
782 if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
783 {
784 ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100785
Manuel Bottinid87aded2021-07-16 10:23:31 +0100786 CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
787 ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
788
789 ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
790 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
791 }
792 _is_prepared = true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100793 }
794}
795
// Report the auxiliary tensor memory requirements (_aux_mem) so the caller can
// allocate the LhsReshape/RhsReshape workspace buffers before run().
experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
800} // namespace opencl
801} // namespace arm_compute