blob: cf1a82bc5a310c23432d3bd9975d58d7e2f7daa8 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
SiCong Libd8b1e22021-02-04 13:07:09 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010026#include "arm_compute/core/CL/CLKernelLibrary.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/CL/ICLTensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010029#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010031#include "arm_compute/core/KernelDescriptors.h"
SiCong Libd8b1e22021-02-04 13:07:09 +000032#include "arm_compute/core/Log.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010033#include "arm_compute/core/TensorInfo.h"
34#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010035#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/core/Validate.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038#include "arm_compute/runtime/CL/CLScheduler.h"
39#include "arm_compute/runtime/ITensorAllocator.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010040#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
41#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
42#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
43#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
44#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010045#include "src/core/helpers/AutoConfiguration.h"
46#include "src/core/utils/helpers/float_ops.h"
47#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
SiCong Libd8b1e22021-02-04 13:07:09 +000048#include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010049#include "support/Cast.h"
SiCong Libd8b1e22021-02-04 13:07:09 +000050#include "utils/TypePrinter.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010051
giuros011c9efeb2019-01-11 14:04:43 +000052namespace arm_compute
53{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010054using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000055using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010056using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010057
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010058namespace weights_transformations
59{
60CLGEMMReshapeRHSMatrixKernelManaged::CLGEMMReshapeRHSMatrixKernelManaged()
Georgios Pinitas40f51a62020-11-21 03:04:18 +000061 : _kernel(std::make_unique<CLGEMMReshapeRHSMatrixKernel>())
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010062{
63}
64
65CLGEMMReshapeRHSMatrixKernelManaged::~CLGEMMReshapeRHSMatrixKernelManaged() = default;
66
67void CLGEMMReshapeRHSMatrixKernelManaged::run()
68{
69 _output.allocator()->allocate();
70 CLScheduler::get().enqueue(*_kernel, false);
71 _reshape_run = true;
72}
73
74void CLGEMMReshapeRHSMatrixKernelManaged::release()
75{
76 _output.allocator()->free();
77}
78
79ICLTensor *CLGEMMReshapeRHSMatrixKernelManaged::get_weights()
80{
81 return &_output;
82}
83
84uint32_t CLGEMMReshapeRHSMatrixKernelManaged::uid()
85{
86 return _uid;
87}
88
89void CLGEMMReshapeRHSMatrixKernelManaged::configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
90{
91 configure(CLKernelLibrary::get().get_compile_context(), input, info);
92}
93
94void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info)
95{
96 _kernel->configure(compile_context, input, &_output, info);
97}
98} // namespace weights_transformations
99
SiCong Libd8b1e22021-02-04 13:07:09 +0000100namespace
101{
SiCong Li1a28e732021-02-10 16:57:33 +0000102inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
103{
104 switch(kernel_type)
105 {
106 case CLGEMMKernelType::NATIVE_V1:
107 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
108 case CLGEMMKernelType::RESHAPED_V1:
109 case CLGEMMKernelType::RESHAPED:
110 {
111 return true;
112 }
113 default:
114 {
115 return false;
116 }
117 }
118}
119//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
120inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run)
121{
122 auto gemm_kernel = auto_heuristics::select_mlgo_gemm_kernel(query, reshape_b_only_on_first_run);
123 if(bool(gemm_kernel))
124 {
125 if(validate_gemm_kernel(gemm_kernel.gemm_type))
126 {
127 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from mlgo heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
128 return gemm_kernel.gemm_type;
129 }
130 }
131 gemm_kernel = auto_heuristics::select_default_gemm_kernel(query, reshape_b_only_on_first_run);
132 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use gemm kernel from default heuristics: %s.", to_string(gemm_kernel.gemm_type).c_str());
133 return gemm_kernel.gemm_type;
134}
SiCong Li8c23ba12021-02-08 14:19:23 +0000135// Validate lhs_info and rhs_info for reshaped only rhs kernel
SiCong Libd8b1e22021-02-04 13:07:09 +0000136inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
137 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info)
138{
139 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel
140 TensorInfo tmp_b_info{};
141 // Validate reshape RHS kernel
142 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
143 if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
144 {
145 return false;
146 }
147 // Validate mm kernel
148 gemm_kernel_info.lhs_info = lhs_info;
149 gemm_kernel_info.rhs_info = rhs_info;
150 gemm_kernel_info.has_pad_y = false;
151 if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
152 {
153 return false;
154 }
155 gemm_kernel_info.has_pad_y = true;
156 if(!bool(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
157 {
158 return false;
159 }
160 return true;
161}
162
SiCong Li8c23ba12021-02-08 14:19:23 +0000163//Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs
SiCong Libd8b1e22021-02-04 13:07:09 +0000164inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a,
165 const ITensorInfo *b,
166 const ITensorInfo *c, const ITensorInfo *output)
167{
168 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped_only_rhs(query);
169 if(config)
170 {
171 if(validate_lhs_rhs_info_reshaped_only_rhs(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info))
172 {
173 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
174 return { config.lhs_info, config.rhs_info };
175 }
176 }
SiCong Lidb353452021-02-08 15:16:13 +0000177 config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query);
SiCong Libd8b1e22021-02-04 13:07:09 +0000178 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
179 return { config.lhs_info, config.rhs_info };
180}
181
SiCong Li8c23ba12021-02-08 14:19:23 +0000182// Validate lhs_info and rhs_info for reshaped kernel
183inline bool validate_lhs_rhs_info_reshaped(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c,
184 const ITensorInfo *output, GEMMKernelInfo gemm_kernel_info, bool reinterpret_input_as_3d)
185{
186 // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped kernel
187 TensorInfo tmp_a_info{};
188 TensorInfo tmp_b_info{};
189
190 // Validate reshape LHS kernel
191 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, reinterpret_input_as_3d)));
192 if(!bool(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, reinterpret_input_as_3d)))
193 {
194 return false;
195 }
196
197 // Validate reshape RHS kernel
198 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
199 if(!bool(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info)))
200 {
201 return false;
202 }
203 // Validate mm kernel
204 gemm_kernel_info.lhs_info = lhs_info;
205 gemm_kernel_info.rhs_info = rhs_info;
206 if(!bool(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, 1.f, 0.f, lhs_info, rhs_info, gemm_kernel_info)))
207 {
208 return false;
209 }
210 return true;
211}
212
213//Automatically select between mlgo (prioritized) and default heuristics for reshaped kernel configs
214inline std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery query, GEMMKernelInfo kernel_info, const ITensorInfo *a, const ITensorInfo *b,
215 const ITensorInfo *c, const ITensorInfo *output, bool reinterpret_input_as_3d)
216{
217 auto config = auto_heuristics::select_mlgo_gemm_config_reshaped(query);
218 if(config)
219 {
220 if(validate_lhs_rhs_info_reshaped(config.lhs_info, config.rhs_info, a, b, c, output, kernel_info, reinterpret_input_as_3d))
221 {
222 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from mlgo heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
223 return { config.lhs_info, config.rhs_info };
224 }
225 }
SiCong Lidb353452021-02-08 15:16:13 +0000226 config = auto_heuristics::select_default_gemm_config_reshaped(query);
SiCong Li8c23ba12021-02-08 14:19:23 +0000227 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), to_string(config.rhs_info).c_str());
228 return { config.lhs_info, config.rhs_info };
229}
230
SiCong Libd8b1e22021-02-04 13:07:09 +0000231} // namespace
232
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100233CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100234 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100235 _weights_manager(weights_manager),
Georgios Pinitas40f51a62020-11-21 03:04:18 +0000236 _mm_kernel(std::make_unique<CLGEMMMatrixMultiplyKernel>()),
237 _reshape_lhs_kernel(std::make_unique<CLGEMMReshapeLHSMatrixKernel>()),
238 _reshape_rhs_kernel(std::make_unique<CLGEMMReshapeRHSMatrixKernel>()),
239 _reshape_rhs_kernel_managed(std::make_unique<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged>()),
240 _mm_reshaped_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedKernel>()),
241 _mm_reshaped_only_rhs_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
242 _mm_reshaped_only_rhs_fallback_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100243 _tmp_a(),
244 _tmp_b(),
245 _original_b(nullptr),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100246 _lhs(nullptr),
247 _dst(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100248 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000249 _is_prepared(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000250 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100251{
252}
253
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100254CLGEMM::~CLGEMM() = default;
255
Manuel Bottini2b84be52020-04-08 10:15:51 +0100256void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
257 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000258{
259 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
260 const unsigned int n = b->info()->dimension(0);
261 const unsigned int k = a->info()->dimension(0);
262 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +0000263
264 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100265 _mm_kernel->set_target(gpu_target);
Gian Marco36a0a462018-01-12 10:21:40 +0000266
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100267 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000268
269 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100270 _mm_kernel->configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000271
272 // Tune kernel statically
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100273 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000274}
275
Manuel Bottini2b84be52020-04-08 10:15:51 +0100276void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
277 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000278{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000279 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
280 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
281 const unsigned int n = b->info()->dimension(0);
282 const unsigned int k = a->info()->dimension(0);
283 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000284 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000285 int mult_transpose1xW_width = 1;
286 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000287
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000288 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100289 _reshape_lhs_kernel->set_target(gpu_target);
290 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000291
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100292 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000293 {
294 mult_transpose1xW_width = 4;
295 mult_interleave4x4_height = 2;
296 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000297
giuros018b6b4a92018-12-18 19:01:33 +0000298 GEMMRHSMatrixInfo rhs_info;
299 rhs_info.n0 = 16 / b->info()->element_size();
300 rhs_info.k0 = 1;
301 rhs_info.h0 = mult_transpose1xW_width;
302 rhs_info.interleave = false;
303 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000304
giuros011c9efeb2019-01-11 14:04:43 +0000305 GEMMLHSMatrixInfo lhs_info;
306 lhs_info.m0 = 4;
307 lhs_info.k0 = 4;
308 lhs_info.v0 = mult_interleave4x4_height;
309 lhs_info.interleave = true;
310 lhs_info.transpose = true;
311
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100312 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000313
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100314 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
315
316 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000317 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100318
319 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100320 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000321 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100322 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100323
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000324 // Configure interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100325 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100326
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000327 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100328 ICLTensor *reshaped_rhs = &_tmp_b;
329 if(_weights_manager && _weights_manager->are_weights_managed(b))
330 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100331 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
332 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100333 }
334 else
335 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100336 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100337 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100338
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000339 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100340 _mm_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000341
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100342 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000343
344 // Allocate intermediate tensors
345 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100346
347 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100348 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000349 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100350 }
351}
352
Manuel Bottini2b84be52020-04-08 10:15:51 +0100353void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
354 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000355{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000356 DataType data_type = a->info()->data_type();
357 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
358 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
359 const unsigned int n = b->info()->dimension(0);
360 const unsigned int k = a->info()->dimension(0);
361 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
362 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
363 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100364 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100365
366 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100367 kernel_info.m = m;
368 kernel_info.n = n;
369 kernel_info.k = k;
370 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
371 kernel_info.reinterpret_input_as_3d = false;
372 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100373 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000374
375 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100376 _reshape_lhs_kernel->set_target(gpu_target);
377 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000378
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100379 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
380
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000381 // Manage intermediate buffers
382 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100383
384 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000385 {
386 _memory_group.manage(&_tmp_b);
387 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100388
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000389 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
390
391 GEMMLHSMatrixInfo lhs_info{};
392 GEMMRHSMatrixInfo rhs_info{};
393
394 // Pick up the GEMM configuration
SiCong Li8c23ba12021-02-08 14:19:23 +0000395 std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(),
396 c == nullptr ? nullptr : c->info(), output->info(), gemm_info.reinterpret_input_as_3d());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000397
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100398 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100399
400 ICLTensor *reshaped_rhs = &_tmp_b;
401 if(_weights_manager && _weights_manager->are_weights_managed(b))
402 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100403 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
404 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100405 }
406 else
407 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100408 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100409 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000410
411 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100412 _mm_reshaped_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000413
414 // Allocate intermediate tensors
415 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100416
417 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000418 {
419 _tmp_b.allocator()->allocate();
420 }
421}
422
Manuel Bottini2b84be52020-04-08 10:15:51 +0100423void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
424 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000425{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000426 DataType data_type = a->info()->data_type();
427 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
428 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
429 const unsigned int n = b->info()->dimension(0);
430 const unsigned int k = a->info()->dimension(0);
431 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
432 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
433 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100434 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100435
436 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100437 kernel_info.m = m;
438 kernel_info.n = n;
439 kernel_info.k = k;
440 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
441 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
442 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100443 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000444
445 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100446 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000447
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100448 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
449
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000450 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100451 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000452 {
453 _memory_group.manage(&_tmp_b);
454 }
455
456 GEMMLHSMatrixInfo lhs_info{};
457 GEMMRHSMatrixInfo rhs_info{};
458
459 // Pick up the GEMM configuration
SiCong Libd8b1e22021-02-04 13:07:09 +0000460 std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }, kernel_info, a->info(), b->info(),
461 c == nullptr ? nullptr : c->info(), output->info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000462
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100463 ICLTensor *reshaped_rhs = &_tmp_b;
464 if(_weights_manager && _weights_manager->are_weights_managed(b))
465 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100466 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
467 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100468 }
469 else
470 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100471 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100472 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000473
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100474 // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
475 // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
476 // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
477
478 // Configure matrix multiply kernel with no y padding support
479 kernel_info.has_pad_y = false;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100480 _mm_reshaped_only_rhs_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000481
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100482 // Configure matrix multiply kernel with y padding support
483 kernel_info.has_pad_y = true;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100484 _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100485
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100486 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000487 {
488 _tmp_b.allocator()->allocate();
489 }
490}
491
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000492Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000493{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100494 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100495 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100496
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000497 // Get the GPU target
498 const GPUTarget gpu_target = CLScheduler::get().target();
499 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
500 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
501 const unsigned int n = b->dimension(0);
502 const unsigned int k = a->dimension(0);
503 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100504
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100505 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000506
507 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100508 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
509 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000510
511 return Status{};
512}
513
514Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
515{
516 ARM_COMPUTE_UNUSED(alpha);
517 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100518
519 TensorInfo tmp_a_info{};
520 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100521
522 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000523 const GPUTarget gpu_target = CLScheduler::get().target();
524 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000525 const unsigned int n = b->dimension(0);
526 const unsigned int k = a->dimension(0);
527 int mult_transpose1xW_width = 1;
528 int mult_interleave4x4_height = 1;
529 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100530
531 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
532 {
533 mult_transpose1xW_width = 4;
534 mult_interleave4x4_height = 2;
535 }
536
giuros018b6b4a92018-12-18 19:01:33 +0000537 GEMMRHSMatrixInfo rhs_info;
538 rhs_info.n0 = 16 / b->element_size();
539 rhs_info.k0 = 1;
540 rhs_info.h0 = mult_transpose1xW_width;
541 rhs_info.interleave = false;
542 rhs_info.transpose = false;
543
giuros011c9efeb2019-01-11 14:04:43 +0000544 GEMMLHSMatrixInfo lhs_info;
545 lhs_info.m0 = 4;
546 lhs_info.k0 = 4;
547 lhs_info.v0 = mult_interleave4x4_height;
548 lhs_info.interleave = true;
549 lhs_info.transpose = true;
550
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100551 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100552
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000553 // Validate interleave kernel
554 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
555 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000556
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000557 // Validate transpose kernel
558 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
559 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000560
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000561 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100562 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
563 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100564
Georgios Pinitas78c00902018-01-09 17:33:11 +0000565 return Status{};
566}
567
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000568Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000569{
570 ARM_COMPUTE_UNUSED(alpha);
571 ARM_COMPUTE_UNUSED(output);
572
573 TensorInfo tmp_a_info{};
574 TensorInfo tmp_b_info{};
575
576 // Get the GPU target
577 const GPUTarget gpu_target = CLScheduler::get().target();
578 DataType data_type = a->data_type();
579 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
580 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
581 const unsigned int n = b->dimension(0);
582 const unsigned int k = a->dimension(0);
583 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
584 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100585 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100586
587 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100588 kernel_info.m = m;
589 kernel_info.n = n;
590 kernel_info.k = k;
591 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
592 kernel_info.reinterpret_input_as_3d = false;
593 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100594 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000595
596 GEMMLHSMatrixInfo lhs_info;
597 GEMMRHSMatrixInfo rhs_info;
598
599 // Pick up the GEMM configuration
SiCong Li8c23ba12021-02-08 14:19:23 +0000600 // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
601 const auto gemm_config = select_default_gemm_config_reshaped(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
602 lhs_info = gemm_config.lhs_info;
603 rhs_info = gemm_config.rhs_info;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000604
605 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
606 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
607
608 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
609 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
610
611 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100612 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000613
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000614 return Status{};
615}
616
617Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
618{
619 ARM_COMPUTE_UNUSED(alpha);
620 ARM_COMPUTE_UNUSED(output);
621
622 TensorInfo tmp_b_info{};
623
624 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100625 const GPUTarget gpu_target = CLScheduler::get().target();
626 const DataType data_type = a->data_type();
627 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
628 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
629 const unsigned int n = b->dimension(0);
630 const unsigned int k = a->dimension(0);
631 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
632 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
633 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100634
635 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100636 kernel_info.m = m;
637 kernel_info.n = n;
638 kernel_info.k = k;
639 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
640 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
641 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100642 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000643
644 GEMMLHSMatrixInfo lhs_info;
645 GEMMRHSMatrixInfo rhs_info;
646
647 // Pick up the GEMM configuration
SiCong Libd8b1e22021-02-04 13:07:09 +0000648 // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails
649 const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size });
650 lhs_info = gemm_config.lhs_info;
651 rhs_info = gemm_config.rhs_info;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000652
653 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
654 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
655
656 // Validate matrix multiply
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100657 kernel_info.has_pad_y = false;
658 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
659
660 kernel_info.has_pad_y = true;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100661 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000662
663 return Status{};
664}
665
666void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
667{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100668 configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info);
669}
670
671void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
672{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000673 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
674
675 // Perform validation step
676 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
677
678 // Check if we need to reshape the matrix B only on the first run
679 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
680 _is_prepared = gemm_info.retain_internal_weights();
681 _original_b = b;
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100682 _lhs = a;
683 _dst = output;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000684
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000685 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
686 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
687 const unsigned int n = b->info()->dimension(0);
688 const unsigned int k = a->info()->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100689 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000690
691 // Select GEMMType
SiCong Libd8b1e22021-02-04 13:07:09 +0000692 _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->info()->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000693
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100694 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100695
696 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000697
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000698 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000699 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000700 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000701 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100702 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000703 break;
704 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000705 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000706 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100707 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000708 break;
709 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000710 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000711 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100712 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000713 break;
714 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000715 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000716 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100717 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000718 break;
719 }
720 default:
721 {
722 ARM_COMPUTE_ERROR("GEMMType not supported");
723 }
724 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000725}
726
727Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
728{
729 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000730 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
731 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
732 const unsigned int n = b->dimension(0);
733 const unsigned int k = a->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100734 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000735
736 // Select GEMMType
SiCong Libd8b1e22021-02-04 13:07:09 +0000737 CLGEMMKernelType gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery
738 {
739 CLScheduler::get().target(), a->data_type(), m, n, k, batch_size,
740 },
741 gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000742
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100743 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100744
745 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
746
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000747 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000748 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000749 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000750 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000751 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000752 break;
753 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000754 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000755 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100756 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000757 break;
758 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000759 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000760 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000761 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000762 break;
763 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000764 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000765 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100766 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000767 break;
768 }
769 default:
770 {
771 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
772 }
773 }
774
775 return Status{};
776}
777
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100778void CLGEMM::run()
779{
Georgios Pinitase0437672018-05-02 14:07:55 +0100780 prepare();
Georgios Pinitasda953f22019-04-02 17:27:03 +0100781 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100782
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100783 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000784 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000785 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000786 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000787 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100788 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000789 break;
790 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000791 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000792 {
793 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100794 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000795
796 if(!_reshape_b_only_on_first_run)
797 {
798 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100799 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
800 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100801 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100802 }
803 else
804 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100805 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100806 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000807 }
808
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100809 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000810 break;
811 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000812 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000813 {
814 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100815 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000816
817 if(!_reshape_b_only_on_first_run)
818 {
819 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100820 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
821 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100822 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100823 }
824 else
825 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100826 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100827 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000828 }
829
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100830 CLScheduler::get().enqueue(*_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000831 break;
832 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000833 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000834 {
835 if(!_reshape_b_only_on_first_run)
836 {
837 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100838 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
839 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100840 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100841 }
842 else
843 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100844 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100845 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000846 }
SiCong Li4d2365d2020-11-08 22:11:01 +0000847 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
848 // Check if the lhs or dst tensors have padding
849 const unsigned int cross_plane_pad_lhs = _lhs->info()->padding().top + _lhs->info()->padding().bottom;
850 const unsigned int cross_plane_pad_dst = _dst->info()->padding().top + _dst->info()->padding().bottom;
851
852 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
853 if(has_pad_y)
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100854 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100855 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_fallback_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100856 }
857 else
858 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100859 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100860 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000861 break;
862 }
863 default:
864 {
865 ARM_COMPUTE_ERROR("GEMMType not supported");
866 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000867 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100868}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100869
Georgios Pinitase0437672018-05-02 14:07:55 +0100870void CLGEMM::prepare()
871{
872 if(!_is_prepared)
873 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000874 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100875 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100876 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
877 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100878 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100879 }
880 else
881 {
882 // Run transpose kernel and mark original weights tensor as unused
883 _tmp_b.allocator()->allocate();
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100884 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100885 _original_b->mark_as_unused();
886 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100887 }
888 CLScheduler::get().queue().finish();
889 _is_prepared = true;
890 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100891}
giuros011c9efeb2019-01-11 14:04:43 +0000892} // namespace arm_compute