blob: 181ae2843be278c0a2ca1a51da99adce385145d3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010026#include "arm_compute/core/CL/CLKernelLibrary.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/CL/ICLTensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010029#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010031#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010034#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include "arm_compute/core/Validate.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037#include "arm_compute/runtime/CL/CLScheduler.h"
38#include "arm_compute/runtime/ITensorAllocator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010039#include "src/core/CL/ICLGEMMKernelConfiguration.h"
40#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
41#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010042#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
43#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
44#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
45#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
46#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010047#include "src/core/helpers/AutoConfiguration.h"
48#include "src/core/utils/helpers/float_ops.h"
49#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
50#include "support/Cast.h"
51
giuros011c9efeb2019-01-11 14:04:43 +000052namespace arm_compute
53{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010054using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000055using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010056using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010057
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010058namespace weights_transformations
59{
60CLGEMMReshapeRHSMatrixKernelManaged::CLGEMMReshapeRHSMatrixKernelManaged()
Georgios Pinitas40f51a62020-11-21 03:04:18 +000061 : _kernel(std::make_unique<CLGEMMReshapeRHSMatrixKernel>())
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010062{
63}
64
65CLGEMMReshapeRHSMatrixKernelManaged::~CLGEMMReshapeRHSMatrixKernelManaged() = default;
66
67void CLGEMMReshapeRHSMatrixKernelManaged::run()
68{
69 _output.allocator()->allocate();
70 CLScheduler::get().enqueue(*_kernel, false);
71 _reshape_run = true;
72}
73
74void CLGEMMReshapeRHSMatrixKernelManaged::release()
75{
76 _output.allocator()->free();
77}
78
79ICLTensor *CLGEMMReshapeRHSMatrixKernelManaged::get_weights()
80{
81 return &_output;
82}
83
84uint32_t CLGEMMReshapeRHSMatrixKernelManaged::uid()
85{
86 return _uid;
87}
88
89void CLGEMMReshapeRHSMatrixKernelManaged::configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
90{
91 configure(CLKernelLibrary::get().get_compile_context(), input, info);
92}
93
94void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info)
95{
96 _kernel->configure(compile_context, input, &_output, info);
97}
98} // namespace weights_transformations
99
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100100CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100101 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100102 _weights_manager(weights_manager),
Georgios Pinitas40f51a62020-11-21 03:04:18 +0000103 _mm_kernel(std::make_unique<CLGEMMMatrixMultiplyKernel>()),
104 _reshape_lhs_kernel(std::make_unique<CLGEMMReshapeLHSMatrixKernel>()),
105 _reshape_rhs_kernel(std::make_unique<CLGEMMReshapeRHSMatrixKernel>()),
106 _reshape_rhs_kernel_managed(std::make_unique<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged>()),
107 _mm_reshaped_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedKernel>()),
108 _mm_reshaped_only_rhs_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
109 _mm_reshaped_only_rhs_fallback_kernel(std::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100110 _tmp_a(),
111 _tmp_b(),
112 _original_b(nullptr),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100113 _lhs(nullptr),
114 _dst(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100115 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000116 _is_prepared(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000117 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100118{
119}
120
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100121CLGEMM::~CLGEMM() = default;
122
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100123CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100124{
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000125 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
126 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100127
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000128 CLGEMMKernelSelectionParams params;
129 params.m = m;
130 params.n = n;
131 params.k = k;
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100132 params.b = b;
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000133 params.is_rhs_constant = reshape_b_only_on_first_run;
134 params.data_type = data_type;
Gian Marco Iodice05639f62019-09-24 12:05:06 +0100135
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000136 return gemm_kernel->select_kernel(params);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000137}
138
Manuel Bottini2b84be52020-04-08 10:15:51 +0100139void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
140 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000141{
142 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
143 const unsigned int n = b->info()->dimension(0);
144 const unsigned int k = a->info()->dimension(0);
145 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +0000146
147 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100148 _mm_kernel->set_target(gpu_target);
Gian Marco36a0a462018-01-12 10:21:40 +0000149
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100150 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000151
152 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100153 _mm_kernel->configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000154
155 // Tune kernel statically
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100156 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000157}
158
Manuel Bottini2b84be52020-04-08 10:15:51 +0100159void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
160 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000161{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000162 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
163 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
164 const unsigned int n = b->info()->dimension(0);
165 const unsigned int k = a->info()->dimension(0);
166 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000167 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000168 int mult_transpose1xW_width = 1;
169 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000170
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000171 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100172 _reshape_lhs_kernel->set_target(gpu_target);
173 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000174
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100175 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000176 {
177 mult_transpose1xW_width = 4;
178 mult_interleave4x4_height = 2;
179 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000180
giuros018b6b4a92018-12-18 19:01:33 +0000181 GEMMRHSMatrixInfo rhs_info;
182 rhs_info.n0 = 16 / b->info()->element_size();
183 rhs_info.k0 = 1;
184 rhs_info.h0 = mult_transpose1xW_width;
185 rhs_info.interleave = false;
186 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000187
giuros011c9efeb2019-01-11 14:04:43 +0000188 GEMMLHSMatrixInfo lhs_info;
189 lhs_info.m0 = 4;
190 lhs_info.k0 = 4;
191 lhs_info.v0 = mult_interleave4x4_height;
192 lhs_info.interleave = true;
193 lhs_info.transpose = true;
194
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100195 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000196
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100197 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
198
199 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000200 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100201
202 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100203 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000204 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100205 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100206
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000207 // Configure interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100208 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100209
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000210 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100211 ICLTensor *reshaped_rhs = &_tmp_b;
212 if(_weights_manager && _weights_manager->are_weights_managed(b))
213 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100214 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
215 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100216 }
217 else
218 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100219 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100220 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100221
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000222 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100223 _mm_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000224
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100225 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000226
227 // Allocate intermediate tensors
228 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100229
230 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100231 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000232 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100233 }
234}
235
Manuel Bottini2b84be52020-04-08 10:15:51 +0100236void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
237 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000238{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000239 DataType data_type = a->info()->data_type();
240 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
241 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
242 const unsigned int n = b->info()->dimension(0);
243 const unsigned int k = a->info()->dimension(0);
244 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
245 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
246 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100247 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100248
249 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100250 kernel_info.m = m;
251 kernel_info.n = n;
252 kernel_info.k = k;
253 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
254 kernel_info.reinterpret_input_as_3d = false;
255 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100256 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000257
258 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100259 _reshape_lhs_kernel->set_target(gpu_target);
260 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000261
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100262 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
263
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000264 // Manage intermediate buffers
265 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100266
267 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000268 {
269 _memory_group.manage(&_tmp_b);
270 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100271
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000272 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
273
274 GEMMLHSMatrixInfo lhs_info{};
275 GEMMRHSMatrixInfo rhs_info{};
276
277 // Pick up the GEMM configuration
278 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
279 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
280
281 // Configure lhs_info and rhs_info
282 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
283
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100284 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100285
286 ICLTensor *reshaped_rhs = &_tmp_b;
287 if(_weights_manager && _weights_manager->are_weights_managed(b))
288 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100289 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
290 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100291 }
292 else
293 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100294 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100295 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000296
297 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100298 _mm_reshaped_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000299
300 // Allocate intermediate tensors
301 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100302
303 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000304 {
305 _tmp_b.allocator()->allocate();
306 }
307}
308
Manuel Bottini2b84be52020-04-08 10:15:51 +0100309void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
310 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000311{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000312 DataType data_type = a->info()->data_type();
313 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
314 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
315 const unsigned int n = b->info()->dimension(0);
316 const unsigned int k = a->info()->dimension(0);
317 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
318 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
319 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100320 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100321
322 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100323 kernel_info.m = m;
324 kernel_info.n = n;
325 kernel_info.k = k;
326 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
327 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
328 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100329 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000330
331 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100332 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000333
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100334 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
335
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000336 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100337 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000338 {
339 _memory_group.manage(&_tmp_b);
340 }
341
342 GEMMLHSMatrixInfo lhs_info{};
343 GEMMRHSMatrixInfo rhs_info{};
344
345 // Pick up the GEMM configuration
346 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
347 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
348
349 // Configure lhs_info and rhs_info
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100350 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000351
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100352 ICLTensor *reshaped_rhs = &_tmp_b;
353 if(_weights_manager && _weights_manager->are_weights_managed(b))
354 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100355 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
356 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100357 }
358 else
359 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100360 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100361 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000362
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100363 // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
364 // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
365 // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
366
367 // Configure matrix multiply kernel with no y padding support
368 kernel_info.has_pad_y = false;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100369 _mm_reshaped_only_rhs_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000370
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100371 // Configure matrix multiply kernel with y padding support
372 kernel_info.has_pad_y = true;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100373 _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100374
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100375 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000376 {
377 _tmp_b.allocator()->allocate();
378 }
379}
380
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000381Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000382{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100383 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100384 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100385
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000386 // Get the GPU target
387 const GPUTarget gpu_target = CLScheduler::get().target();
388 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
389 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
390 const unsigned int n = b->dimension(0);
391 const unsigned int k = a->dimension(0);
392 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100393
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100394 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000395
396 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100397 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
398 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000399
400 return Status{};
401}
402
403Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
404{
405 ARM_COMPUTE_UNUSED(alpha);
406 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100407
408 TensorInfo tmp_a_info{};
409 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100410
411 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000412 const GPUTarget gpu_target = CLScheduler::get().target();
413 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000414 const unsigned int n = b->dimension(0);
415 const unsigned int k = a->dimension(0);
416 int mult_transpose1xW_width = 1;
417 int mult_interleave4x4_height = 1;
418 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100419
420 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
421 {
422 mult_transpose1xW_width = 4;
423 mult_interleave4x4_height = 2;
424 }
425
giuros018b6b4a92018-12-18 19:01:33 +0000426 GEMMRHSMatrixInfo rhs_info;
427 rhs_info.n0 = 16 / b->element_size();
428 rhs_info.k0 = 1;
429 rhs_info.h0 = mult_transpose1xW_width;
430 rhs_info.interleave = false;
431 rhs_info.transpose = false;
432
giuros011c9efeb2019-01-11 14:04:43 +0000433 GEMMLHSMatrixInfo lhs_info;
434 lhs_info.m0 = 4;
435 lhs_info.k0 = 4;
436 lhs_info.v0 = mult_interleave4x4_height;
437 lhs_info.interleave = true;
438 lhs_info.transpose = true;
439
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100440 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100441
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000442 // Validate interleave kernel
443 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
444 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000445
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000446 // Validate transpose kernel
447 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
448 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000449
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000450 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100451 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
452 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100453
Georgios Pinitas78c00902018-01-09 17:33:11 +0000454 return Status{};
455}
456
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000457Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000458{
459 ARM_COMPUTE_UNUSED(alpha);
460 ARM_COMPUTE_UNUSED(output);
461
462 TensorInfo tmp_a_info{};
463 TensorInfo tmp_b_info{};
464
465 // Get the GPU target
466 const GPUTarget gpu_target = CLScheduler::get().target();
467 DataType data_type = a->data_type();
468 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
469 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
470 const unsigned int n = b->dimension(0);
471 const unsigned int k = a->dimension(0);
472 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
473 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100474 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100475
476 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100477 kernel_info.m = m;
478 kernel_info.n = n;
479 kernel_info.k = k;
480 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
481 kernel_info.reinterpret_input_as_3d = false;
482 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100483 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000484
485 GEMMLHSMatrixInfo lhs_info;
486 GEMMRHSMatrixInfo rhs_info;
487
488 // Pick up the GEMM configuration
489 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
490 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
491
492 // Configure lhs_info and rhs_info
493 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
494
495 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
496 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
497
498 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
499 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
500
501 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100502 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000503
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000504 return Status{};
505}
506
507Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
508{
509 ARM_COMPUTE_UNUSED(alpha);
510 ARM_COMPUTE_UNUSED(output);
511
512 TensorInfo tmp_b_info{};
513
514 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100515 const GPUTarget gpu_target = CLScheduler::get().target();
516 const DataType data_type = a->data_type();
517 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
518 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
519 const unsigned int n = b->dimension(0);
520 const unsigned int k = a->dimension(0);
521 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
522 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
523 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100524
525 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100526 kernel_info.m = m;
527 kernel_info.n = n;
528 kernel_info.k = k;
529 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
530 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
531 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100532 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000533
534 GEMMLHSMatrixInfo lhs_info;
535 GEMMRHSMatrixInfo rhs_info;
536
537 // Pick up the GEMM configuration
538 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
539 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
540
541 // Configure lhs_info and rhs_info
542 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
543
544 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
545 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
546
547 // Validate matrix multiply
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100548 kernel_info.has_pad_y = false;
549 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
550
551 kernel_info.has_pad_y = true;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100552 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000553
554 return Status{};
555}
556
557void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
558{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100559 configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info);
560}
561
562void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
563{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000564 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
565
566 // Perform validation step
567 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
568
569 // Check if we need to reshape the matrix B only on the first run
570 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
571 _is_prepared = gemm_info.retain_internal_weights();
572 _original_b = b;
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100573 _lhs = a;
574 _dst = output;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000575
576 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000577 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
578 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
579 const unsigned int n = b->info()->dimension(0);
580 const unsigned int k = a->info()->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100581 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000582
583 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100584 _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000585
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100586 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100587
588 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000589
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000590 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000591 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000592 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000593 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100594 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000595 break;
596 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000597 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000598 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100599 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000600 break;
601 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000602 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000603 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100604 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000605 break;
606 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000607 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000608 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100609 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000610 break;
611 }
612 default:
613 {
614 ARM_COMPUTE_ERROR("GEMMType not supported");
615 }
616 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000617}
618
619Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
620{
621 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000622 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
623 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
624 const unsigned int n = b->dimension(0);
625 const unsigned int k = a->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100626 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000627
628 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100629 CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000630
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100631 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100632
633 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
634
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000635 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000636 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000637 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000638 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000639 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000640 break;
641 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000642 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000643 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100644 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000645 break;
646 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000647 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000648 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000649 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000650 break;
651 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000652 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000653 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100654 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000655 break;
656 }
657 default:
658 {
659 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
660 }
661 }
662
663 return Status{};
664}
665
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100666void CLGEMM::run()
667{
Georgios Pinitase0437672018-05-02 14:07:55 +0100668 prepare();
Georgios Pinitasda953f22019-04-02 17:27:03 +0100669 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100670
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100671 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000672 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000673 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000674 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000675 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100676 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000677 break;
678 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000679 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000680 {
681 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100682 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000683
684 if(!_reshape_b_only_on_first_run)
685 {
686 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100687 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
688 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100689 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100690 }
691 else
692 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100693 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100694 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000695 }
696
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100697 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000698 break;
699 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000700 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000701 {
702 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100703 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000704
705 if(!_reshape_b_only_on_first_run)
706 {
707 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100708 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
709 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100710 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100711 }
712 else
713 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100714 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100715 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000716 }
717
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100718 CLScheduler::get().enqueue(*_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000719 break;
720 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000721 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000722 {
723 if(!_reshape_b_only_on_first_run)
724 {
725 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100726 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
727 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100728 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100729 }
730 else
731 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100732 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100733 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000734 }
SiCong Li4d2365d2020-11-08 22:11:01 +0000735 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
736 // Check if the lhs or dst tensors have padding
737 const unsigned int cross_plane_pad_lhs = _lhs->info()->padding().top + _lhs->info()->padding().bottom;
738 const unsigned int cross_plane_pad_dst = _dst->info()->padding().top + _dst->info()->padding().bottom;
739
740 bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
741 if(has_pad_y)
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100742 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100743 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_fallback_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100744 }
745 else
746 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100747 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100748 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000749 break;
750 }
751 default:
752 {
753 ARM_COMPUTE_ERROR("GEMMType not supported");
754 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000755 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100756}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100757
Georgios Pinitase0437672018-05-02 14:07:55 +0100758void CLGEMM::prepare()
759{
760 if(!_is_prepared)
761 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000762 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100763 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100764 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
765 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100766 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100767 }
768 else
769 {
770 // Run transpose kernel and mark original weights tensor as unused
771 _tmp_b.allocator()->allocate();
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100772 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100773 _original_b->mark_as_unused();
774 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100775 }
776 CLScheduler::get().queue().finish();
777 _is_prepared = true;
778 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100779}
giuros011c9efeb2019-01-11 14:04:43 +0000780} // namespace arm_compute