blob: 50ecb214e3b0971838b994d633dfc144e84e5217 [file] [log] [blame]
Georgios Pinitas856f66e2021-04-22 21:13:21 +01001/*
2 * Copyright (c) 2017-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/gpu/cl/operators/ClGemm.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010025
26#include "arm_compute/core/CL/CLKernelLibrary.h"
27#include "arm_compute/core/CL/ICLTensor.h"
28#include "arm_compute/core/Error.h"
29#include "arm_compute/core/GPUTarget.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/KernelDescriptors.h"
32#include "arm_compute/core/Log.h"
33#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/Types.h"
35#include "arm_compute/core/Utils.h"
36#include "arm_compute/core/Validate.h"
37#include "arm_compute/core/utils/misc/ShapeCalculator.h"
38#include "arm_compute/runtime/CL/CLScheduler.h"
39#include "arm_compute/runtime/ITensorAllocator.h"
Georgios Pinitas2b147ee2021-07-08 18:14:45 +010040
SiCongLi579ca842021-10-18 09:38:33 +010041#include "arm_compute/core/experimental/IPostOp.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010042#include "src/core/helpers/AutoConfiguration.h"
43#include "src/core/helpers/MemoryHelpers.h"
44#include "src/core/utils/helpers/float_ops.h"
Georgios Pinitas7891a732021-08-20 21:39:25 +010045#include "src/gpu/cl/IClKernel.h"
46#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010047#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
48#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010049
ramelg012e53f172021-09-22 10:48:25 +010050#include "src/common/utils/Log.h"
Georgios Pinitas856f66e2021-04-22 21:13:21 +010051#include "support/Cast.h"
52#include "utils/TypePrinter.h"
53
54namespace arm_compute
55{
56namespace opencl
57{
58using namespace arm_compute::misc::shape_calculator;
59using namespace arm_compute::cl_gemm;
60using namespace arm_compute::experimental;
61using namespace arm_compute::utils::cast;
62using namespace arm_compute::opencl::kernels;
63
64namespace
65{
66inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
67{
SiCongLi579ca842021-10-18 09:38:33 +010068 return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +010069}
70//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
Giorgio Arena4403ed32021-05-17 13:03:50 +010071inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
Georgios Pinitas856f66e2021-04-22 21:13:21 +010072{
Giorgio Arena4403ed32021-05-17 13:03:50 +010073 if(!constant_weights)
74 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +010075 return CLGEMMKernelType::NATIVE;
Giorgio Arena4403ed32021-05-17 13:03:50 +010076 }
77
Georgios Pinitas856f66e2021-04-22 21:13:21 +010078 auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
79 if(bool(gemm_kernel))
80 {
81 if(validate_gemm_kernel(gemm_kernel.gemm_type))
82 {
83 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
84 return gemm_kernel.gemm_type;
85 }
86 }
87 gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
88 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
89 return gemm_kernel.gemm_type;
90}
91// Validate lhs_info and rhs_info for reshaped only rhs kernel
92inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
93 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
94{
95 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
96 TensorInfo tmp_b_info{};
97 // Validate reshape RHS kernel
98 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
99 if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
100 {
101 return false;
102 }
103 // Validate mm kernel
104 gemm_kernel_info.lhs_info = lhs_info;
105 gemm_kernel_info.rhs_info = rhs_info;
106 gemm_kernel_info.has_pad_y = false;
107 if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
108 {
109 return false;
110 }
111 gemm_kernel_info.has_pad_y = true;
112 if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
113 {
114 return false;
115 }
116 return true;
117}
118
119//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
120inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
121 const ITensorInfo *b,
122 const ITensorInfo *c, const ITensorInfo *output)
123{
124 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
125 if(config)
126 {
127 if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
128 {
129 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
130 return { config.lhs_info, config.rhs_info };
131 }
132 }
133 config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
134 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
135 return { config.lhs_info, config.rhs_info };
136}
137
138// Validate lhs_info and rhs_info for reshaped kernel
139inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
140 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
141{
142 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
143 TensorInfo tmp_a_info{};
144 TensorInfo tmp_b_info{};
145
146 // Validate reshape LHS kernel
147 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
148 if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
149 {
150 return false;
151 }
152
153 // Validate reshape RHS kernel
154 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
155 if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
156 {
157 return false;
158 }
159 // Validate mm kernel
160 gemm_kernel_info.lhs_info = lhs_info;
161 gemm_kernel_info.rhs_info = rhs_info;
162 if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
163 {
164 return false;
165 }
166 return true;
167}
168
169//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
170inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
171 const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
172{
173 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
174 if(config)
175 {
176 if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
177 {
178 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
179 return { config.lhs_info, config.rhs_info };
180 }
181 }
182 config = auto_heuristics::select_default_gemm_config_reshaped(query);
183 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
184 return { config.lhs_info, config.rhs_info };
185}
186} // namespace
187
// Default-construct all candidate kernels and bookkeeping state.
// The actual kernel selection and configuration happens later in configure().
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _mm_reshaped_only_rhs_fallback_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()), // y-padding-aware variant
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE),
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count) // one slot per auxiliary tensor (reshaped LHS/RHS)
{
}
203
// Configure the native (non-reshaped) matrix multiply path: a single kernel,
// no auxiliary reshape tensors are requested.
void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    // Derive the GEMM problem dimensions. With reinterpret_input_as_3d the M
    // dimension is the product of the two inner dimensions of 'a' and the batch
    // index moves up to dimension 3.
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    // Descriptor consumed by the matrix multiply kernel
    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    // NOTE(review): the lhs/rhs tile configuration is taken from the
    // reshaped_only_rhs mlgo heuristics even though this is the native kernel —
    // confirm this reuse is intentional and not a copy/paste of the
    // reshaped_only_rhs path.
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}
235
// Configure the fully-reshaped path: both LHS and RHS are reshaped into
// auxiliary tensors (_tmp_a/_tmp_b) before the matrix multiply kernel runs.
void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    // Derive the GEMM problem dimensions. With reinterpret_input_as_3d the M
    // dimension is the product of the two inner dimensions of 'a'.
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    // The LHS reshape kernel consumes the 3D reinterpretation (it is configured
    // with gemm_info.reinterpret_input_as_3d() below), so the mm kernel sees a 2D LHS.
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo heuristics first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    // Configure the reshape kernels; their outputs land in the auxiliary tensors _tmp_a/_tmp_b
    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix. The reshaped RHS is kept
    // persistent when B is only reshaped on the first run so prepare() can reuse it.
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
280
// Configure the reshaped-only-RHS path: only the RHS matrix is reshaped into an
// auxiliary tensor; two mm kernel variants (with/without y padding) are prepared
// and the choice between them is deferred to run().
void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    // Derive the GEMM problem dimensions. With reinterpret_input_as_3d the M
    // dimension is the product of the two inner dimensions of 'a'.
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    bool               broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);
    _mm_reshaped_only_rhs_fallback_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration (mlgo heuristics first, default heuristics as fallback)
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix: only the RHS is reshaped in this variant
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
    // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
    // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false

    // Configure matrix multiply kernel with no y padding support
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Configure matrix multiply kernel with y padding support
    kernel_info.has_pad_y = true;
    _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix; kept persistent when B is only
    // reshaped on the first run so prepare() can reuse it.
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}
332
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100333Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100334{
335 ARM_COMPUTE_UNUSED(alpha);
336 ARM_COMPUTE_UNUSED(output);
337
338 // Get the GPU target
339 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100340 DataType data_type = a->data_type();
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100341 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
342 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
343 const unsigned int n = b->dimension(0);
344 const unsigned int k = a->dimension(0);
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100345 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100346 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100347 const bool broadcast_bias = gemm_info.broadcast_bias();
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100348
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100349 GEMMKernelInfo kernel_info;
350 kernel_info.m = m;
351 kernel_info.n = n;
352 kernel_info.k = k;
353 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
354 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
355 kernel_info.broadcast_bias = broadcast_bias;
356 kernel_info.activation_info = gemm_info.activation_info();
SiCongLiafa19722021-10-24 19:12:33 +0100357 kernel_info.post_ops = gemm_info.post_ops();
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100358
359 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100360
361 // Validate matrix multiply
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100362 ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100363
364 return Status{};
365}
366
// Static validation of the fully-reshaped path: checks the LHS/RHS reshape
// kernels and the reshaped matrix multiply kernel against the shapes the
// reshape kernels would produce.
Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // NOTE(review): alpha and output are in fact forwarded to the kernel
    // validation calls below, so these UNUSED annotations look stale.
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    DataType           data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    // The LHS reshape consumes the 3D reinterpretation, so the mm kernel is validated with a 2D LHS
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    // Validate the LHS reshape kernel against the shape it would produce
    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    // Validate the RHS reshape kernel against the shape it would produce
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
416
// Static validation of the reshaped-only-RHS path: checks the RHS reshape kernel
// and both mm kernel variants (has_pad_y = false and true), since the runtime
// chooses between them at execution time.
Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // NOTE(review): alpha and output are in fact forwarded to the kernel
    // validation calls below, so these UNUSED annotations look stale.
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget    gpu_target              = CLScheduler::get().target();
    const DataType     data_type               = a->data_type();
    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int          depth_output_gemm3d     = gemm_info.depth_output_gemm3d();
    const bool         broadcast_bias          = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m                       = m;
    kernel_info.n                       = n;
    kernel_info.k                       = k;
    kernel_info.depth_output_gemm3d     = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias          = broadcast_bias;
    kernel_info.activation_info         = gemm_info.activation_info();
    kernel_info.post_ops                = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info               = gemm_config.lhs_info;
    rhs_info               = gemm_config.rhs_info;

    // Validate the RHS reshape kernel against the shape it would produce
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply for both padding variants
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}
466
// Entry point: validate the arguments, select a GEMM kernel type via the
// heuristics, then dispatch to the kernel-type-specific configure_* helper.
void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    // When internal weights are retained the RHS is treated as already
    // transformed, which makes prepare() a no-op (it checks _is_prepared).
    _is_prepared = gemm_info.retain_internal_weights();

    bool               reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m                       = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n                       = b->dimension(0);
    const unsigned int k                       = a->dimension(0);
    const unsigned int batch_size              = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType (mlgo heuristics prioritized over default heuristics)
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    // Fuse the bias addition only when beta != 0 and a bias tensor is provided
    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}
516
517Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
518{
519 // Get the GPU target
520 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
521 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
522 const unsigned int n = b->dimension(0);
523 const unsigned int k = a->dimension(0);
524 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
525
526 // Select GEMMType
527 CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
528 {
529 CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
530 },
Giorgio Arena63e0beb2021-09-24 14:04:27 +0100531 gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100532
533 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
534
535 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
536
537 switch(gemm_kernel_type)
538 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100539 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100540 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100541 ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100542 break;
543 }
544 case CLGEMMKernelType::RESHAPED:
545 {
546 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
547 break;
548 }
549 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
550 {
551 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
552 break;
553 }
554 default:
555 {
556 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
557 }
558 }
559
560 return Status{};
561}
562
563void ClGemm::run(ITensorPack &tensors)
564{
SiCongLiafa19722021-10-24 19:12:33 +0100565 const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
566 const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
567 ITensor *dst = tensors.get_tensor(ACL_DST);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100568
569 ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);
570
571 CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
572 CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);
573
574 // Prepare the consts if needed
575 prepare(tensors);
576
577 // Run matrix multiply kernel
578 switch(_gemm_kernel_type)
579 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100580 case CLGEMMKernelType::NATIVE:
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100581 {
Gian Marco Iodicec9cecc02021-10-15 10:23:24 +0100582 CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100583 break;
584 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100585 case CLGEMMKernelType::RESHAPED:
586 {
587 // Run interleave kernel
588 ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
589 CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);
590
591 if(!_reshape_b_only_on_first_run)
592 {
593 // Run transpose kernel
594 ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
595 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
596 }
SiCongLi579ca842021-10-18 09:38:33 +0100597 // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
598 ITensorPack gemm_reshaped_pack(tensors);
599 gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
600 gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
Manuel Bottinid87aded2021-07-16 10:23:31 +0100601
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100602 if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
603 {
604 CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
605 }
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100606 break;
607 }
608 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
609 {
610 if(!_reshape_b_only_on_first_run)
611 {
612 // Run transpose kernel
613 ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
614 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
615 }
616 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
617 // Check if the lhs or dst tensors have padding
618 const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
619 const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
620 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
621
SiCongLiafa19722021-10-24 19:12:33 +0100622 // Copy original tensor pack and overwrite rhs with reshaped counterpart
623 ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
624 gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
625
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100626 if(has_pad_y)
627 {
628 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_fallback_kernel, gemm_reshaped_onlyrhs_pack, true);
629 }
630 else
631 {
632 CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
633 }
634 break;
635 }
636 default:
637 {
638 ARM_COMPUTE_ERROR("GEMMType not supported");
639 }
640 }
641}
642
643void ClGemm::prepare(ITensorPack &constants)
644{
Manuel Bottinid87aded2021-07-16 10:23:31 +0100645 if(!_is_prepared)
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100646 {
Manuel Bottinid87aded2021-07-16 10:23:31 +0100647 const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
648 ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));
Georgios Pinitas2b147ee2021-07-08 18:14:45 +0100649
Manuel Bottinid87aded2021-07-16 10:23:31 +0100650 // If memory for RHS is persistent and src1 is provided re-transform else assume that RHS is transformed
651 if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr) && rhs_aux)
652 {
653 ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100654
Manuel Bottinid87aded2021-07-16 10:23:31 +0100655 CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
656 ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);
657
658 ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
659 CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
660 }
661 _is_prepared = true;
Georgios Pinitas856f66e2021-04-22 21:13:21 +0100662 }
663}
664
// Report the auxiliary (workspace) tensor requirements recorded during
// configure() — i.e. the memory needed for the reshaped LHS/RHS matrices.
experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
669} // namespace opencl
670} // namespace arm_compute