blob: 0151485849dccfebe94ccee53d2fc5756fe3dc9c [file] [log] [blame]
/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010026#include "arm_compute/core/CL/CLKernelLibrary.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/CL/ICLTensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010029#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010031#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/TensorInfo.h"
33#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010034#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include "arm_compute/core/Validate.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037#include "arm_compute/runtime/CL/CLScheduler.h"
38#include "arm_compute/runtime/ITensorAllocator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010039#include "src/core/CL/ICLGEMMKernelConfiguration.h"
40#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
41#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010042#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
43#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
44#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
45#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
46#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010047#include "src/core/helpers/AutoConfiguration.h"
48#include "src/core/utils/helpers/float_ops.h"
49#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
50#include "support/Cast.h"
51
52#include "support/MemorySupport.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010053
giuros011c9efeb2019-01-11 14:04:43 +000054namespace arm_compute
55{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010056using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000057using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010058using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010059
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +010060namespace weights_transformations
61{
62CLGEMMReshapeRHSMatrixKernelManaged::CLGEMMReshapeRHSMatrixKernelManaged()
63 : _kernel(support::cpp14::make_unique<CLGEMMReshapeRHSMatrixKernel>())
64{
65}
66
67CLGEMMReshapeRHSMatrixKernelManaged::~CLGEMMReshapeRHSMatrixKernelManaged() = default;
68
69void CLGEMMReshapeRHSMatrixKernelManaged::run()
70{
71 _output.allocator()->allocate();
72 CLScheduler::get().enqueue(*_kernel, false);
73 _reshape_run = true;
74}
75
76void CLGEMMReshapeRHSMatrixKernelManaged::release()
77{
78 _output.allocator()->free();
79}
80
81ICLTensor *CLGEMMReshapeRHSMatrixKernelManaged::get_weights()
82{
83 return &_output;
84}
85
86uint32_t CLGEMMReshapeRHSMatrixKernelManaged::uid()
87{
88 return _uid;
89}
90
91void CLGEMMReshapeRHSMatrixKernelManaged::configure(const ICLTensor *input, GEMMRHSMatrixInfo info)
92{
93 configure(CLKernelLibrary::get().get_compile_context(), input, info);
94}
95
96void CLGEMMReshapeRHSMatrixKernelManaged::configure(const CLCompileContext &compile_context, const ICLTensor *input, GEMMRHSMatrixInfo info)
97{
98 _kernel->configure(compile_context, input, &_output, info);
99}
100} // namespace weights_transformations
101
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100102CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100103 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100104 _weights_manager(weights_manager),
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100105 _mm_kernel(support::cpp14::make_unique<CLGEMMMatrixMultiplyKernel>()),
106 _reshape_lhs_kernel(support::cpp14::make_unique<CLGEMMReshapeLHSMatrixKernel>()),
107 _reshape_rhs_kernel(support::cpp14::make_unique<CLGEMMReshapeRHSMatrixKernel>()),
108 _reshape_rhs_kernel_managed(support::cpp14::make_unique<weights_transformations::CLGEMMReshapeRHSMatrixKernelManaged>()),
109 _mm_reshaped_kernel(support::cpp14::make_unique<CLGEMMMatrixMultiplyReshapedKernel>()),
110 _mm_reshaped_only_rhs_kernel(support::cpp14::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
111 _mm_reshaped_only_rhs_fallback_kernel(support::cpp14::make_unique<CLGEMMMatrixMultiplyReshapedOnlyRHSKernel>()),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100112 _tmp_a(),
113 _tmp_b(),
114 _original_b(nullptr),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100115 _lhs(nullptr),
116 _dst(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +0100117 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000118 _is_prepared(false),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100119 _has_pad_y(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000120 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100121{
122}
123
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100124CLGEMM::~CLGEMM() = default;
125
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100126CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100127{
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000128 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
129 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100130
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000131 CLGEMMKernelSelectionParams params;
132 params.m = m;
133 params.n = n;
134 params.k = k;
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100135 params.b = b;
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000136 params.is_rhs_constant = reshape_b_only_on_first_run;
137 params.data_type = data_type;
Gian Marco Iodice05639f62019-09-24 12:05:06 +0100138
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000139 return gemm_kernel->select_kernel(params);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000140}
141
Manuel Bottini2b84be52020-04-08 10:15:51 +0100142void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
143 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000144{
145 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
146 const unsigned int n = b->info()->dimension(0);
147 const unsigned int k = a->info()->dimension(0);
148 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +0000149
150 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100151 _mm_kernel->set_target(gpu_target);
Gian Marco36a0a462018-01-12 10:21:40 +0000152
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100153 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000154
155 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100156 _mm_kernel->configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000157
158 // Tune kernel statically
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100159 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000160}
161
Manuel Bottini2b84be52020-04-08 10:15:51 +0100162void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
163 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000164{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000165 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
166 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
167 const unsigned int n = b->info()->dimension(0);
168 const unsigned int k = a->info()->dimension(0);
169 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000170 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000171 int mult_transpose1xW_width = 1;
172 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000173
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000174 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100175 _reshape_lhs_kernel->set_target(gpu_target);
176 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000177
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100178 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000179 {
180 mult_transpose1xW_width = 4;
181 mult_interleave4x4_height = 2;
182 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000183
giuros018b6b4a92018-12-18 19:01:33 +0000184 GEMMRHSMatrixInfo rhs_info;
185 rhs_info.n0 = 16 / b->info()->element_size();
186 rhs_info.k0 = 1;
187 rhs_info.h0 = mult_transpose1xW_width;
188 rhs_info.interleave = false;
189 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000190
giuros011c9efeb2019-01-11 14:04:43 +0000191 GEMMLHSMatrixInfo lhs_info;
192 lhs_info.m0 = 4;
193 lhs_info.k0 = 4;
194 lhs_info.v0 = mult_interleave4x4_height;
195 lhs_info.interleave = true;
196 lhs_info.transpose = true;
197
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100198 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000199
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100200 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
201
202 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000203 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100204
205 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100206 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000207 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100208 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100209
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000210 // Configure interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100211 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100212
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000213 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100214 ICLTensor *reshaped_rhs = &_tmp_b;
215 if(_weights_manager && _weights_manager->are_weights_managed(b))
216 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100217 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
218 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100219 }
220 else
221 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100222 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100223 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100224
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000225 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100226 _mm_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000227
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100228 CLScheduler::get().tune_kernel_static(*_mm_kernel);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000229
230 // Allocate intermediate tensors
231 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100232
233 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000235 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100236 }
237}
238
Manuel Bottini2b84be52020-04-08 10:15:51 +0100239void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
240 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000241{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000242 DataType data_type = a->info()->data_type();
243 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
244 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
245 const unsigned int n = b->info()->dimension(0);
246 const unsigned int k = a->info()->dimension(0);
247 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
248 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
249 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100250 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100251
252 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100253 kernel_info.m = m;
254 kernel_info.n = n;
255 kernel_info.k = k;
256 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
257 kernel_info.reinterpret_input_as_3d = false;
258 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100259 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000260
261 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100262 _reshape_lhs_kernel->set_target(gpu_target);
263 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000264
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100265 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
266
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000267 // Manage intermediate buffers
268 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100269
270 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000271 {
272 _memory_group.manage(&_tmp_b);
273 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100274
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000275 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
276
277 GEMMLHSMatrixInfo lhs_info{};
278 GEMMRHSMatrixInfo rhs_info{};
279
280 // Pick up the GEMM configuration
281 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
282 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
283
284 // Configure lhs_info and rhs_info
285 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
286
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100287 _reshape_lhs_kernel->configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100288
289 ICLTensor *reshaped_rhs = &_tmp_b;
290 if(_weights_manager && _weights_manager->are_weights_managed(b))
291 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100292 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
293 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100294 }
295 else
296 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100297 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100298 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000299
300 // Configure and tune matrix multiply kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100301 _mm_reshaped_kernel->configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000302
303 // Allocate intermediate tensors
304 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100305
306 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000307 {
308 _tmp_b.allocator()->allocate();
309 }
310}
311
Manuel Bottini2b84be52020-04-08 10:15:51 +0100312void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
313 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000314{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000315 DataType data_type = a->info()->data_type();
316 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
317 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
318 const unsigned int n = b->info()->dimension(0);
319 const unsigned int k = a->info()->dimension(0);
320 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
321 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
322 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100323 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100324
325 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100326 kernel_info.m = m;
327 kernel_info.n = n;
328 kernel_info.k = k;
329 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
330 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
331 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100332 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000333
334 // Set the target for the kernels
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100335 _mm_kernel->set_target(gpu_target);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000336
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100337 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
338
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000339 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100340 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000341 {
342 _memory_group.manage(&_tmp_b);
343 }
344
345 GEMMLHSMatrixInfo lhs_info{};
346 GEMMRHSMatrixInfo rhs_info{};
347
348 // Pick up the GEMM configuration
349 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
350 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
351
352 // Configure lhs_info and rhs_info
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100353 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000354
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100355 ICLTensor *reshaped_rhs = &_tmp_b;
356 if(_weights_manager && _weights_manager->are_weights_managed(b))
357 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100358 _reshape_rhs_kernel_managed->configure(compile_context, b, rhs_info);
359 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, _reshape_rhs_kernel_managed.get()));
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100360 }
361 else
362 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100363 _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100364 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000365
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100366 // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
367 // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
368 // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
369
370 // Configure matrix multiply kernel with no y padding support
371 kernel_info.has_pad_y = false;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100372 _mm_reshaped_only_rhs_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000373
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100374 // Configure matrix multiply kernel with y padding support
375 kernel_info.has_pad_y = true;
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100376 _mm_reshaped_only_rhs_fallback_kernel->configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100377
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100378 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000379 {
380 _tmp_b.allocator()->allocate();
381 }
382}
383
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000384Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000385{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100386 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100387 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100388
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000389 // Get the GPU target
390 const GPUTarget gpu_target = CLScheduler::get().target();
391 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
392 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
393 const unsigned int n = b->dimension(0);
394 const unsigned int k = a->dimension(0);
395 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100396
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100397 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000398
399 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100400 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
401 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000402
403 return Status{};
404}
405
406Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
407{
408 ARM_COMPUTE_UNUSED(alpha);
409 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100410
411 TensorInfo tmp_a_info{};
412 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100413
414 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000415 const GPUTarget gpu_target = CLScheduler::get().target();
416 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000417 const unsigned int n = b->dimension(0);
418 const unsigned int k = a->dimension(0);
419 int mult_transpose1xW_width = 1;
420 int mult_interleave4x4_height = 1;
421 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100422
423 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
424 {
425 mult_transpose1xW_width = 4;
426 mult_interleave4x4_height = 2;
427 }
428
giuros018b6b4a92018-12-18 19:01:33 +0000429 GEMMRHSMatrixInfo rhs_info;
430 rhs_info.n0 = 16 / b->element_size();
431 rhs_info.k0 = 1;
432 rhs_info.h0 = mult_transpose1xW_width;
433 rhs_info.interleave = false;
434 rhs_info.transpose = false;
435
giuros011c9efeb2019-01-11 14:04:43 +0000436 GEMMLHSMatrixInfo lhs_info;
437 lhs_info.m0 = 4;
438 lhs_info.k0 = 4;
439 lhs_info.v0 = mult_interleave4x4_height;
440 lhs_info.interleave = true;
441 lhs_info.transpose = true;
442
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100443 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100444
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000445 // Validate interleave kernel
446 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
447 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000448
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000449 // Validate transpose kernel
450 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
451 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000452
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000453 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100454 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
455 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100456
Georgios Pinitas78c00902018-01-09 17:33:11 +0000457 return Status{};
458}
459
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000460Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000461{
462 ARM_COMPUTE_UNUSED(alpha);
463 ARM_COMPUTE_UNUSED(output);
464
465 TensorInfo tmp_a_info{};
466 TensorInfo tmp_b_info{};
467
468 // Get the GPU target
469 const GPUTarget gpu_target = CLScheduler::get().target();
470 DataType data_type = a->data_type();
471 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
472 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
473 const unsigned int n = b->dimension(0);
474 const unsigned int k = a->dimension(0);
475 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
476 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100477 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100478
479 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100480 kernel_info.m = m;
481 kernel_info.n = n;
482 kernel_info.k = k;
483 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
484 kernel_info.reinterpret_input_as_3d = false;
485 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100486 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000487
488 GEMMLHSMatrixInfo lhs_info;
489 GEMMRHSMatrixInfo rhs_info;
490
491 // Pick up the GEMM configuration
492 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
493 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
494
495 // Configure lhs_info and rhs_info
496 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
497
498 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
499 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
500
501 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
502 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
503
504 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100505 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000506
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000507 return Status{};
508}
509
510Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
511{
512 ARM_COMPUTE_UNUSED(alpha);
513 ARM_COMPUTE_UNUSED(output);
514
515 TensorInfo tmp_b_info{};
516
517 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100518 const GPUTarget gpu_target = CLScheduler::get().target();
519 const DataType data_type = a->data_type();
520 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
521 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
522 const unsigned int n = b->dimension(0);
523 const unsigned int k = a->dimension(0);
524 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
525 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
526 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100527
528 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100529 kernel_info.m = m;
530 kernel_info.n = n;
531 kernel_info.k = k;
532 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
533 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
534 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100535 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000536
537 GEMMLHSMatrixInfo lhs_info;
538 GEMMRHSMatrixInfo rhs_info;
539
540 // Pick up the GEMM configuration
541 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
542 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
543
544 // Configure lhs_info and rhs_info
545 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
546
547 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
548 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
549
550 // Validate matrix multiply
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100551 kernel_info.has_pad_y = false;
552 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
553
554 kernel_info.has_pad_y = true;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100555 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000556
557 return Status{};
558}
559
560void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
561{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100562 configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info);
563}
564
565void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
566{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000567 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
568
569 // Perform validation step
570 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
571
572 // Check if we need to reshape the matrix B only on the first run
573 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
574 _is_prepared = gemm_info.retain_internal_weights();
575 _original_b = b;
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100576 _lhs = a;
577 _dst = output;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000578
579 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000580 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
581 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
582 const unsigned int n = b->info()->dimension(0);
583 const unsigned int k = a->info()->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100584 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000585
586 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100587 _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000588
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100589 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100590
591 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000592
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000593 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000594 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000595 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000596 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100597 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000598 break;
599 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000600 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000601 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100602 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000603 break;
604 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000605 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000606 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100607 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000608 break;
609 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000610 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000611 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100612 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000613 break;
614 }
615 default:
616 {
617 ARM_COMPUTE_ERROR("GEMMType not supported");
618 }
619 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000620}
621
622Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
623{
624 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000625 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
626 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
627 const unsigned int n = b->dimension(0);
628 const unsigned int k = a->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100629 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000630
631 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100632 CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000633
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100634 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100635
636 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
637
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000638 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000639 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000640 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000641 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000642 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000643 break;
644 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000645 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000646 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100647 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000648 break;
649 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000650 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000651 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000652 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000653 break;
654 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000655 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000656 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100657 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000658 break;
659 }
660 default:
661 {
662 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
663 }
664 }
665
666 return Status{};
667}
668
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100669void CLGEMM::run()
670{
Georgios Pinitase0437672018-05-02 14:07:55 +0100671 prepare();
Georgios Pinitasda953f22019-04-02 17:27:03 +0100672 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100673
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100674 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000675 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000676 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000677 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000678 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100679 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000680 break;
681 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000682 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000683 {
684 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100685 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000686
687 if(!_reshape_b_only_on_first_run)
688 {
689 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100690 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
691 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100692 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100693 }
694 else
695 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100696 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100697 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000698 }
699
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100700 CLScheduler::get().enqueue(*_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000701 break;
702 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000703 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000704 {
705 // Run interleave kernel
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100706 CLScheduler::get().enqueue(*_reshape_lhs_kernel, false);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000707
708 if(!_reshape_b_only_on_first_run)
709 {
710 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100711 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
712 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100713 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100714 }
715 else
716 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100717 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100718 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000719 }
720
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100721 CLScheduler::get().enqueue(*_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000722 break;
723 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000724 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000725 {
726 if(!_reshape_b_only_on_first_run)
727 {
728 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100729 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
730 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100731 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100732 }
733 else
734 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100735 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100736 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000737 }
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100738 if(_has_pad_y)
739 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100740 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_fallback_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100741 }
742 else
743 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100744 CLScheduler::get().enqueue(*_mm_reshaped_only_rhs_kernel, true);
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100745 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000746 break;
747 }
748 default:
749 {
750 ARM_COMPUTE_ERROR("GEMMType not supported");
751 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000752 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100753}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100754
Georgios Pinitase0437672018-05-02 14:07:55 +0100755void CLGEMM::prepare()
756{
757 if(!_is_prepared)
758 {
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100759 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
760 if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS)
761 {
762 // Check if the lhs or dst tensors have padding
763 const unsigned int cross_plane_pad_lhs = _lhs->info()->padding().top + _lhs->info()->padding().bottom;
764 const unsigned int cross_plane_pad_dst = _dst->info()->padding().top + _dst->info()->padding().bottom;
765
766 _has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
767 }
768
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000769 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100770 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100771 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
772 {
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100773 _weights_manager->run(_original_b, _reshape_rhs_kernel_managed.get());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100774 }
775 else
776 {
777 // Run transpose kernel and mark original weights tensor as unused
778 _tmp_b.allocator()->allocate();
Sang-Hoon Parkbef7fa22020-10-21 15:58:54 +0100779 CLScheduler::get().enqueue(*_reshape_rhs_kernel, false);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100780 _original_b->mark_as_unused();
781 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100782 }
783 CLScheduler::get().queue().finish();
784 _is_prepared = true;
785 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100786}
giuros011c9efeb2019-01-11 14:04:43 +0000787} // namespace arm_compute