blob: 80c5496ede66dda7e367c57a6434e9cdbb6a982d [file] [log] [blame]
/*
 * Copyright (c) 2017-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
26#include "arm_compute/core/CL/ICLTensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010028#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010030#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010033#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/Validate.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/runtime/CL/CLScheduler.h"
37#include "arm_compute/runtime/ITensorAllocator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010038#include "src/core/CL/ICLGEMMKernelConfiguration.h"
39#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
40#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
41#include "src/core/helpers/AutoConfiguration.h"
42#include "src/core/utils/helpers/float_ops.h"
43#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
44#include "support/Cast.h"
45
46#include "support/MemorySupport.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047
giuros011c9efeb2019-01-11 14:04:43 +000048namespace arm_compute
49{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010050using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000051using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010052using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010053
Michalis Spyroub27e13a2019-09-27 11:04:27 +010054CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010055 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010056 _weights_manager(weights_manager),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010057 _mm_kernel(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000058 _reshape_lhs_kernel(),
59 _reshape_rhs_kernel(),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010060 _reshape_rhs_kernel_managed(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000061 _mm_reshaped_kernel(),
Gian Marco Iodice926afe12019-03-19 11:44:13 +000062 _mm_reshaped_only_rhs_kernel(),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +010063 _mm_reshaped_only_rhs_fallback_kernel(),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010064 _tmp_a(),
65 _tmp_b(),
66 _original_b(nullptr),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +010067 _lhs(nullptr),
68 _dst(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010069 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000070 _is_prepared(false),
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +010071 _has_pad_y(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000072 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010073{
74}
75
Gian Marco Iodice026d0452020-08-28 13:52:12 +010076CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010077{
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000078 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
79 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
Anthony Barbier6ff3b192017-09-04 18:44:23 +010080
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000081 CLGEMMKernelSelectionParams params;
82 params.m = m;
83 params.n = n;
84 params.k = k;
Gian Marco Iodice026d0452020-08-28 13:52:12 +010085 params.b = b;
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000086 params.is_rhs_constant = reshape_b_only_on_first_run;
87 params.data_type = data_type;
Gian Marco Iodice05639f62019-09-24 12:05:06 +010088
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000089 return gemm_kernel->select_kernel(params);
Gian Marco Iodice926afe12019-03-19 11:44:13 +000090}
91
Manuel Bottini2b84be52020-04-08 10:15:51 +010092void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
93 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +000094{
95 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
96 const unsigned int n = b->info()->dimension(0);
97 const unsigned int k = a->info()->dimension(0);
98 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +000099
100 // Set the target for the kernels
Gian Marco36a0a462018-01-12 10:21:40 +0000101 _mm_kernel.set_target(gpu_target);
102
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100103 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000104
105 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100106 _mm_kernel.configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000107
108 // Tune kernel statically
109 CLScheduler::get().tune_kernel_static(_mm_kernel);
110}
111
Manuel Bottini2b84be52020-04-08 10:15:51 +0100112void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
113 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000114{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000115 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
116 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
117 const unsigned int n = b->info()->dimension(0);
118 const unsigned int k = a->info()->dimension(0);
119 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000120 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000121 int mult_transpose1xW_width = 1;
122 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000123
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000124 // Set the target for the kernels
125 _reshape_lhs_kernel.set_target(gpu_target);
126 _mm_kernel.set_target(gpu_target);
127
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100128 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000129 {
130 mult_transpose1xW_width = 4;
131 mult_interleave4x4_height = 2;
132 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000133
giuros018b6b4a92018-12-18 19:01:33 +0000134 GEMMRHSMatrixInfo rhs_info;
135 rhs_info.n0 = 16 / b->info()->element_size();
136 rhs_info.k0 = 1;
137 rhs_info.h0 = mult_transpose1xW_width;
138 rhs_info.interleave = false;
139 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000140
giuros011c9efeb2019-01-11 14:04:43 +0000141 GEMMLHSMatrixInfo lhs_info;
142 lhs_info.m0 = 4;
143 lhs_info.k0 = 4;
144 lhs_info.v0 = mult_interleave4x4_height;
145 lhs_info.interleave = true;
146 lhs_info.transpose = true;
147
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100148 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000149
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100150 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
151
152 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000153 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100154
155 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100156 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000157 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100158 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000160 // Configure interleave kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100161 _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100162
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000163 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100164 ICLTensor *reshaped_rhs = &_tmp_b;
165 if(_weights_manager && _weights_manager->are_weights_managed(b))
166 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100167 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100168 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
169 }
170 else
171 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100172 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100173 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100174
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000175 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100176 _mm_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000177
178 CLScheduler::get().tune_kernel_static(_mm_kernel);
179
180 // Allocate intermediate tensors
181 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100182
183 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100184 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000185 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100186 }
187}
188
Manuel Bottini2b84be52020-04-08 10:15:51 +0100189void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
190 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000191{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000192 DataType data_type = a->info()->data_type();
193 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
194 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
195 const unsigned int n = b->info()->dimension(0);
196 const unsigned int k = a->info()->dimension(0);
197 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
198 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
199 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100200 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100201
202 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100203 kernel_info.m = m;
204 kernel_info.n = n;
205 kernel_info.k = k;
206 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
207 kernel_info.reinterpret_input_as_3d = false;
208 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100209 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000210
211 // Set the target for the kernels
212 _reshape_lhs_kernel.set_target(gpu_target);
213 _mm_kernel.set_target(gpu_target);
214
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100215 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
216
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000217 // Manage intermediate buffers
218 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100219
220 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000221 {
222 _memory_group.manage(&_tmp_b);
223 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100224
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000225 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
226
227 GEMMLHSMatrixInfo lhs_info{};
228 GEMMRHSMatrixInfo rhs_info{};
229
230 // Pick up the GEMM configuration
231 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
232 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
233
234 // Configure lhs_info and rhs_info
235 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
236
Manuel Bottini2b84be52020-04-08 10:15:51 +0100237 _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100238
239 ICLTensor *reshaped_rhs = &_tmp_b;
240 if(_weights_manager && _weights_manager->are_weights_managed(b))
241 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100242 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100243 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
244 }
245 else
246 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100247 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100248 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000249
250 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100251 _mm_reshaped_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000252
253 // Allocate intermediate tensors
254 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100255
256 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000257 {
258 _tmp_b.allocator()->allocate();
259 }
260}
261
Manuel Bottini2b84be52020-04-08 10:15:51 +0100262void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
263 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000264{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000265 DataType data_type = a->info()->data_type();
266 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
267 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
268 const unsigned int n = b->info()->dimension(0);
269 const unsigned int k = a->info()->dimension(0);
270 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
271 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
272 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100273 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100274
275 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100276 kernel_info.m = m;
277 kernel_info.n = n;
278 kernel_info.k = k;
279 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
280 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
281 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100282 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000283
284 // Set the target for the kernels
285 _mm_kernel.set_target(gpu_target);
286
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100287 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
288
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000289 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100290 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000291 {
292 _memory_group.manage(&_tmp_b);
293 }
294
295 GEMMLHSMatrixInfo lhs_info{};
296 GEMMRHSMatrixInfo rhs_info{};
297
298 // Pick up the GEMM configuration
299 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
300 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
301
302 // Configure lhs_info and rhs_info
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100303 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000304
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100305 ICLTensor *reshaped_rhs = &_tmp_b;
306 if(_weights_manager && _weights_manager->are_weights_managed(b))
307 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100308 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100309 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
310 }
311 else
312 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100313 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100314 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000315
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100316 // Configure two variants of CLGEMMMatrixMultiplyReshapedOnlyRHSKernel (has_pad_y = false/true)
317 // During the prepare stage we check the padding requirement for the lhs and dst tensors. If they do not have
318 // pad y, we dispatch CLGEMMMatrixMultiplyReshapedOnlyRHSKernel with has_pad_y = false
319
320 // Configure matrix multiply kernel with no y padding support
321 kernel_info.has_pad_y = false;
Manuel Bottini2b84be52020-04-08 10:15:51 +0100322 _mm_reshaped_only_rhs_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000323
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100324 // Configure matrix multiply kernel with y padding support
325 kernel_info.has_pad_y = true;
326 _mm_reshaped_only_rhs_fallback_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
327
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100328 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000329 {
330 _tmp_b.allocator()->allocate();
331 }
332}
333
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000334Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000335{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100336 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100337 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100338
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000339 // Get the GPU target
340 const GPUTarget gpu_target = CLScheduler::get().target();
341 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
342 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
343 const unsigned int n = b->dimension(0);
344 const unsigned int k = a->dimension(0);
345 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100346
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100347 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000348
349 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100350 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
351 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000352
353 return Status{};
354}
355
356Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
357{
358 ARM_COMPUTE_UNUSED(alpha);
359 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100360
361 TensorInfo tmp_a_info{};
362 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100363
364 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000365 const GPUTarget gpu_target = CLScheduler::get().target();
366 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000367 const unsigned int n = b->dimension(0);
368 const unsigned int k = a->dimension(0);
369 int mult_transpose1xW_width = 1;
370 int mult_interleave4x4_height = 1;
371 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100372
373 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
374 {
375 mult_transpose1xW_width = 4;
376 mult_interleave4x4_height = 2;
377 }
378
giuros018b6b4a92018-12-18 19:01:33 +0000379 GEMMRHSMatrixInfo rhs_info;
380 rhs_info.n0 = 16 / b->element_size();
381 rhs_info.k0 = 1;
382 rhs_info.h0 = mult_transpose1xW_width;
383 rhs_info.interleave = false;
384 rhs_info.transpose = false;
385
giuros011c9efeb2019-01-11 14:04:43 +0000386 GEMMLHSMatrixInfo lhs_info;
387 lhs_info.m0 = 4;
388 lhs_info.k0 = 4;
389 lhs_info.v0 = mult_interleave4x4_height;
390 lhs_info.interleave = true;
391 lhs_info.transpose = true;
392
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100393 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100394
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000395 // Validate interleave kernel
396 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
397 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000398
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000399 // Validate transpose kernel
400 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
401 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000402
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000403 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100404 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
405 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100406
Georgios Pinitas78c00902018-01-09 17:33:11 +0000407 return Status{};
408}
409
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000410Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000411{
412 ARM_COMPUTE_UNUSED(alpha);
413 ARM_COMPUTE_UNUSED(output);
414
415 TensorInfo tmp_a_info{};
416 TensorInfo tmp_b_info{};
417
418 // Get the GPU target
419 const GPUTarget gpu_target = CLScheduler::get().target();
420 DataType data_type = a->data_type();
421 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
422 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
423 const unsigned int n = b->dimension(0);
424 const unsigned int k = a->dimension(0);
425 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
426 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100427 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100428
429 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100430 kernel_info.m = m;
431 kernel_info.n = n;
432 kernel_info.k = k;
433 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
434 kernel_info.reinterpret_input_as_3d = false;
435 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100436 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000437
438 GEMMLHSMatrixInfo lhs_info;
439 GEMMRHSMatrixInfo rhs_info;
440
441 // Pick up the GEMM configuration
442 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
443 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
444
445 // Configure lhs_info and rhs_info
446 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
447
448 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
449 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
450
451 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
452 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
453
454 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100455 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000456
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000457 return Status{};
458}
459
460Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
461{
462 ARM_COMPUTE_UNUSED(alpha);
463 ARM_COMPUTE_UNUSED(output);
464
465 TensorInfo tmp_b_info{};
466
467 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100468 const GPUTarget gpu_target = CLScheduler::get().target();
469 const DataType data_type = a->data_type();
470 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
471 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
472 const unsigned int n = b->dimension(0);
473 const unsigned int k = a->dimension(0);
474 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
475 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
476 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100477
478 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100479 kernel_info.m = m;
480 kernel_info.n = n;
481 kernel_info.k = k;
482 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
483 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
484 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100485 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000486
487 GEMMLHSMatrixInfo lhs_info;
488 GEMMRHSMatrixInfo rhs_info;
489
490 // Pick up the GEMM configuration
491 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
492 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
493
494 // Configure lhs_info and rhs_info
495 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
496
497 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
498 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
499
500 // Validate matrix multiply
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100501 kernel_info.has_pad_y = false;
502 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
503
504 kernel_info.has_pad_y = true;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100505 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000506
507 return Status{};
508}
509
510void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
511{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100512 configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info);
513}
514
515void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
516{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000517 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
518
519 // Perform validation step
520 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
521
522 // Check if we need to reshape the matrix B only on the first run
523 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
524 _is_prepared = gemm_info.retain_internal_weights();
525 _original_b = b;
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100526 _lhs = a;
527 _dst = output;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000528
529 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000530 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
531 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
532 const unsigned int n = b->info()->dimension(0);
533 const unsigned int k = a->info()->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100534 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000535
536 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100537 _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000538
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100539 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100540
541 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000542
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000543 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000544 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000545 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000546 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100547 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000548 break;
549 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000550 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000551 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100552 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000553 break;
554 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000555 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000556 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100557 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000558 break;
559 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000560 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000561 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100562 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000563 break;
564 }
565 default:
566 {
567 ARM_COMPUTE_ERROR("GEMMType not supported");
568 }
569 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000570}
571
572Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
573{
574 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000575 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
576 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
577 const unsigned int n = b->dimension(0);
578 const unsigned int k = a->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100579 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000580
581 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100582 CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000583
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100584 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100585
586 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
587
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000588 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000589 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000590 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000591 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000592 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000593 break;
594 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000595 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000596 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100597 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000598 break;
599 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000600 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000601 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000602 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000603 break;
604 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000605 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000606 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100607 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000608 break;
609 }
610 default:
611 {
612 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
613 }
614 }
615
616 return Status{};
617}
618
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100619void CLGEMM::run()
620{
Georgios Pinitase0437672018-05-02 14:07:55 +0100621 prepare();
Georgios Pinitasda953f22019-04-02 17:27:03 +0100622 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100623
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100624 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000625 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000626 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000627 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000628 {
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100629 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000630 break;
631 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000632 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000633 {
634 // Run interleave kernel
635 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
636
637 if(!_reshape_b_only_on_first_run)
638 {
639 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100640 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
641 {
642 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
643 }
644 else
645 {
646 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
647 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000648 }
649
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100650 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000651 break;
652 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000653 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000654 {
655 // Run interleave kernel
656 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
657
658 if(!_reshape_b_only_on_first_run)
659 {
660 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100661 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
662 {
663 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
664 }
665 else
666 {
667 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
668 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000669 }
670
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100671 CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000672 break;
673 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000674 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000675 {
676 if(!_reshape_b_only_on_first_run)
677 {
678 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100679 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
680 {
681 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
682 }
683 else
684 {
685 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
686 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000687 }
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100688 if(_has_pad_y)
689 {
690 CLScheduler::get().enqueue(_mm_reshaped_only_rhs_fallback_kernel, true);
691 }
692 else
693 {
694 CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
695 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000696 break;
697 }
698 default:
699 {
700 ARM_COMPUTE_ERROR("GEMMType not supported");
701 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000702 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100703}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100704
Georgios Pinitase0437672018-05-02 14:07:55 +0100705void CLGEMM::prepare()
706{
707 if(!_is_prepared)
708 {
Gian Marco Iodice9ae06d42020-10-22 16:37:12 +0100709 // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement
710 if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS)
711 {
712 // Check if the lhs or dst tensors have padding
713 const unsigned int cross_plane_pad_lhs = _lhs->info()->padding().top + _lhs->info()->padding().bottom;
714 const unsigned int cross_plane_pad_dst = _dst->info()->padding().top + _dst->info()->padding().bottom;
715
716 _has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0);
717 }
718
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000719 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100720 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100721 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
722 {
723 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
724 }
725 else
726 {
727 // Run transpose kernel and mark original weights tensor as unused
728 _tmp_b.allocator()->allocate();
729 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
730 _original_b->mark_as_unused();
731 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100732 }
733 CLScheduler::get().queue().finish();
734 _is_prepared = true;
735 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100736}
giuros011c9efeb2019-01-11 14:04:43 +0000737} // namespace arm_compute