/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/gpu/cl/operators/ClGemm.h"

#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/GPUTarget.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Log.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"

#include "arm_compute/core/experimental/IPostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
#include "src/gpu/cl/IClKernel.h"
#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"

#include "src/common/utils/Log.h"
#include "support/Cast.h"
#include "utils/TypePrinter.h"

namespace arm_compute
{
namespace opencl
{
using namespace arm_compute::misc::shape_calculator;
using namespace arm_compute::cl_gemm;
using namespace arm_compute::experimental;
using namespace arm_compute::utils::cast;
using namespace arm_compute::opencl::kernels;

namespace
{
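// The helpers in this anonymous namespace select the GEMM kernel type and the LHS/RHS block
// configurations. The MLGO heuristics are queried first and their suggestion is validated against
// the kernels; if the suggestion is rejected, the built-in default heuristics are used instead.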
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
    return kernel_type != CLGEMMKernelType::NATIVE;
}
// Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
{
    if(!constant_weights)
    {
        return CLGEMMKernelType::NATIVE;
    }

    auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
    if(bool(gemm_kernel))
    {
        if(validate_gemm_kernel(gemm_kernel.gemm_type))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
            return gemm_kernel.gemm_type;
        }
    }
    gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
    return gemm_kernel.gemm_type;
}
// Validate lhs_info and rhs_info for reshaped only rhs kernel
inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                                    const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
    TensorInfo tmp_b_info{};
    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    gemm_kernel_info.has_pad_y = false;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    gemm_kernel_info.has_pad_y = true;
    if(!bool(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
                                                                                                 const ITensorInfo *b,
                                                                                                 const ITensorInfo *c, const ITensorInfo *output)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}

// Validate lhs_info and rhs_info for reshaped kernel
inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
                                           const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
{
    // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Validate reshape LHS kernel
    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
    if(!bool(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
    {
        return false;
    }

    // Validate reshape RHS kernel
    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
    {
        return false;
    }
    // Validate mm kernel
    gemm_kernel_info.lhs_info = lhs_info;
    gemm_kernel_info.rhs_info = rhs_info;
    if(!bool(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
    {
        return false;
    }
    return true;
}

// Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
                                                                                        const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
{
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
    if(config)
    {
        if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
        {
            ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
            return { config.lhs_info, config.rhs_info };
        }
    }
    config = auto_heuristics::select_default_gemm_config_reshaped(query);
    ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
    return { config.lhs_info, config.rhs_info };
}
} // namespace

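// ClGemm owns one kernel object per supported GEMM variant (LHS/RHS reshape, native, reshaped,
// reshaped-only-RHS) and decides at configure time which of them are actually used.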
ClGemm::ClGemm()
    : _reshape_lhs_kernel(std::make_unique<ClGemmReshapeLhsMatrixKernel>()),
      _reshape_rhs_kernel(std::make_unique<ClGemmReshapeRhsMatrixKernel>()),
      _mm_native_kernel(std::make_unique<ClGemmMatrixMultiplyNativeKernel>()),
      _mm_reshaped_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedKernel>()),
      _mm_reshaped_only_rhs_kernel(std::make_unique<ClGemmMatrixMultiplyReshapedOnlyRhsKernel>()),
      _tmp_a(),
      _tmp_b(),
      _reshape_b_only_on_first_run(false),
      _gemm_kernel_type(CLGEMMKernelType::NATIVE),
      _is_prepared(false),
      _aux_mem(AuxTensorIdx::Count)
{
}

void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                              const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_native_kernel->set_target(gpu_target);

    // Note: the native kernel takes its lhs/rhs block configuration from the reshaped-only-RHS heuristics
    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Configure and tune matrix multiply kernel
    _mm_native_kernel->configure(compile_context, a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info);
}

void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _reshape_lhs_kernel->set_target(gpu_target);
    _mm_reshaped_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b,
                                                                    c, output, gemm_info.reinterpret_input_as_3d());

    _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure and tune matrix multiply kernel
    _mm_reshaped_kernel->configure(compile_context, &_tmp_a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for LHS and RHS reshape matrix
    _aux_mem[LhsReshape] = MemoryInfo(offset_int_vec(LhsReshape), MemoryLifetime::Temporary, _tmp_a.total_size());
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
                                         const GEMMInfo &gemm_info)
{
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const GPUTarget gpu_target = CLScheduler::get().target();
    bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    // Set the target for the kernels
    _mm_reshaped_only_rhs_kernel->set_target(gpu_target);

    GEMMLHSMatrixInfo lhs_info{};
    GEMMRHSMatrixInfo rhs_info{};

    // Pick up the GEMM configuration
    std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a, b, c, output);

    // Transpose matrix
    _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);

    // Configure the matrix multiply kernel with no y padding support (has_pad_y = false).
    // The padding requirement for the lhs and dst tensors is checked at run time; dispatching with
    // has_pad_y = true is currently not supported (see ClGemm::run()).
    kernel_info.has_pad_y = false;
    _mm_reshaped_only_rhs_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);

    // Request memory for RHS reshape matrix
    _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size());
}

Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyNativeKernel::validate(a, b, c, output, alpha, beta, config.lhs_info, config.rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_a_info{};
    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = false;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeLhsMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    ARM_COMPUTE_UNUSED(output);

    TensorInfo tmp_b_info{};

    // Get the GPU target
    const GPUTarget gpu_target = CLScheduler::get().target();
    const DataType data_type = a->data_type();
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
    const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
    const bool broadcast_bias = gemm_info.broadcast_bias();

    GEMMKernelInfo kernel_info;
    kernel_info.m = m;
    kernel_info.n = n;
    kernel_info.k = k;
    kernel_info.depth_output_gemm3d = depth_output_gemm3d;
    kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
    kernel_info.broadcast_bias = broadcast_bias;
    kernel_info.activation_info = gemm_info.activation_info();
    kernel_info.post_ops = gemm_info.post_ops();

    GEMMLHSMatrixInfo lhs_info;
    GEMMRHSMatrixInfo rhs_info;

    // Pick up the GEMM configuration
    // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
    const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
    lhs_info = gemm_config.lhs_info;
    rhs_info = gemm_config.rhs_info;

    auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info));

    // Validate matrix multiply
    kernel_info.has_pad_y = false;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    kernel_info.has_pad_y = true;
    ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));

    return Status{};
}

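// Public entry point: validates the arguments, selects a kernel type through the heuristics and
// dispatches to the matching configure_* method above.
//
// A minimal usage sketch (illustrative only; the CL context, tensor allocation and GEMMInfo setup
// follow the usual library conventions and are assumed here rather than shown):
//
//   ClGemm gemm;
//   gemm.configure(CLKernelLibrary::get().get_compile_context(), &a_info, &b_info, nullptr, &dst_info, 1.f, 0.f, GEMMInfo());
//   ITensorPack pack{ { ACL_SRC_0, &a }, { ACL_SRC_1, &b }, { ACL_DST, &dst } };
//   gemm.run(pack);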
void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate(a, b, c, output, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, output, alpha, beta, gemm_info);

    // Check if we need to reshape the matrix B only on the first run
    _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
    _is_prepared = gemm_info.retain_internal_weights();

    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run,
                                                b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            configure_native(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            configure_reshaped(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

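// Pre-flight check that mirrors configure(): it selects the same kernel type and runs the
// corresponding validate_* method, without configuring any kernels or requesting memory.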
Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
{
    // Gather the problem dimensions
    bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
    const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
    const unsigned int n = b->dimension(0);
    const unsigned int k = a->dimension(0);
    const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);

    // Select GEMMType
    CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
    {
        CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
    },
    gemm_info.reshape_b_only_on_first_run(), b->are_values_constant());

    const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);

    const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;

    switch(gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_native(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
            break;
        }
        default:
        {
            ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
        }
    }

    return Status{};
}

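// Executes the kernels selected at configure time. The pack is expected to bind the operand tensors
// (ACL_SRC_0 = lhs, ACL_SRC_1 = rhs, ACL_DST = dst) and, for the reshape-based variants, the auxiliary
// reshape tensors under the offset_int_vec(LhsReshape/RhsReshape) ids reported by workspace().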
void ClGemm::run(ITensorPack &tensors)
{
    const ITensor *lhs = tensors.get_const_tensor(ACL_SRC_0);
    const ITensor *rhs = tensors.get_const_tensor(ACL_SRC_1);
    ITensor *dst = tensors.get_tensor(ACL_DST);

    ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, dst);

    CLAuxTensorHandler lhs_reshaped(offset_int_vec(LhsReshape), _tmp_a, tensors, true);
    CLAuxTensorHandler rhs_reshaped(offset_int_vec(RhsReshape), _tmp_b, tensors, true);

    // Prepare the consts if needed
    prepare(tensors);

    // Run matrix multiply kernel
    switch(_gemm_kernel_type)
    {
        case CLGEMMKernelType::NATIVE:
        {
            CLScheduler::get().enqueue_op(*_mm_native_kernel, tensors, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED:
        {
            // Run interleave kernel
            ITensorPack reshape_lhs_pack{ { ACL_SRC, lhs }, { ACL_DST, lhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_lhs_kernel, reshape_lhs_pack, false);

            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
            ITensorPack gemm_reshaped_pack(tensors);
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
            gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            CLScheduler::get().enqueue_op(*_mm_reshaped_kernel, gemm_reshaped_pack, true);
            break;
        }
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:
        {
            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
                CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
            }
            // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
            // Check if the lhs or dst tensors have padding
            const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom;
            const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom;
            bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);

            // Copy original tensor pack and overwrite rhs with reshaped counterpart
            ITensorPack gemm_reshaped_onlyrhs_pack(tensors);
            gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());

            if(has_pad_y)
            {
                // The has_pad_y = true kernel variant is not configured, so y padding on lhs/dst is not supported here
                ARM_COMPUTE_ERROR_ON(has_pad_y);
            }
            else
            {
                CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_onlyrhs_pack, true);
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("GEMMType not supported");
        }
    }
}

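// One-off preparation: when the reshaped RHS is persistent (constant B reshaped only on the first
// run), the RHS provided in the pack is transformed here once so that run() can skip the reshape
// on subsequent executions.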
void ClGemm::prepare(ITensorPack &constants)
{
    if(!_is_prepared)
    {
        const ITensor *src1 = constants.get_const_tensor(ACL_SRC_1);
        ICLTensor *rhs_aux = utils::cast::polymorphic_downcast<ICLTensor *>(constants.get_tensor(offset_int_vec(RhsReshape)));

        // If memory for RHS is persistent and src1 is provided, re-transform; otherwise assume that RHS is already transformed
        if((_aux_mem[AuxTensorIdx::RhsReshape].lifetime == MemoryLifetime::Persistent) && (src1 != nullptr && rhs_aux != nullptr))
        {
            ARM_COMPUTE_LOG_INFO_WITH_FUNCNAME_ACL("Transforming RHS Matrix!");

            CLAuxTensorHandler rhs_reshaped(_tmp_b, *rhs_aux);
            ARM_COMPUTE_ERROR_ON(rhs_reshaped.get()->cl_buffer().get() == nullptr);

            ITensorPack reshape_rhs_pack{ { ACL_SRC, src1 }, { ACL_DST, rhs_reshaped.get() } };
            CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, true);
        }
        _is_prepared = true;
    }
}

experimental::MemoryRequirements ClGemm::workspace() const
{
    return _aux_mem;
}
} // namespace opencl
} // namespace arm_compute