blob: ccae6713a61cbdaa3fb59eb3c8281389a885546b [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
26#include "arm_compute/core/CL/ICLTensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010028#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010030#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/TensorInfo.h"
32#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010033#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/Validate.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010035#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/runtime/CL/CLScheduler.h"
37#include "arm_compute/runtime/ITensorAllocator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010038#include "src/core/CL/ICLGEMMKernelConfiguration.h"
39#include "src/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
40#include "src/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
41#include "src/core/helpers/AutoConfiguration.h"
42#include "src/core/utils/helpers/float_ops.h"
43#include "src/runtime/CL/gemm/CLGEMMKernelSelection.h"
44#include "support/Cast.h"
45
46#include "support/MemorySupport.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010047
giuros011c9efeb2019-01-11 14:04:43 +000048namespace arm_compute
49{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010050using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000051using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010052using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010053
Michalis Spyroub27e13a2019-09-27 11:04:27 +010054CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010055 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010056 _weights_manager(weights_manager),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010057 _mm_kernel(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000058 _reshape_lhs_kernel(),
59 _reshape_rhs_kernel(),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010060 _reshape_rhs_kernel_managed(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000061 _mm_reshaped_kernel(),
Gian Marco Iodice926afe12019-03-19 11:44:13 +000062 _mm_reshaped_only_rhs_kernel(),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010063 _tmp_a(),
64 _tmp_b(),
65 _original_b(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010066 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000067 _is_prepared(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000068 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010069{
70}
71
Gian Marco Iodice026d0452020-08-28 13:52:12 +010072CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type, bool reshape_b_only_on_first_run)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010073{
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000074 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
75 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
Anthony Barbier6ff3b192017-09-04 18:44:23 +010076
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000077 CLGEMMKernelSelectionParams params;
78 params.m = m;
79 params.n = n;
80 params.k = k;
Gian Marco Iodice026d0452020-08-28 13:52:12 +010081 params.b = b;
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000082 params.is_rhs_constant = reshape_b_only_on_first_run;
83 params.data_type = data_type;
Gian Marco Iodice05639f62019-09-24 12:05:06 +010084
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000085 return gemm_kernel->select_kernel(params);
Gian Marco Iodice926afe12019-03-19 11:44:13 +000086}
87
Manuel Bottini2b84be52020-04-08 10:15:51 +010088void CLGEMM::configure_native_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
89 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +000090{
91 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
92 const unsigned int n = b->info()->dimension(0);
93 const unsigned int k = a->info()->dimension(0);
94 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +000095
96 // Set the target for the kernels
Gian Marco36a0a462018-01-12 10:21:40 +000097 _mm_kernel.set_target(gpu_target);
98
Gian Marco Iodicef3622be2019-07-29 14:27:16 +010099 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000100
101 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100102 _mm_kernel.configure(compile_context, a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000103
104 // Tune kernel statically
105 CLScheduler::get().tune_kernel_static(_mm_kernel);
106}
107
Manuel Bottini2b84be52020-04-08 10:15:51 +0100108void CLGEMM::configure_reshaped_v1(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
109 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000110{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000111 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
112 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
113 const unsigned int n = b->info()->dimension(0);
114 const unsigned int k = a->info()->dimension(0);
115 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000116 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000117 int mult_transpose1xW_width = 1;
118 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000119
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000120 // Set the target for the kernels
121 _reshape_lhs_kernel.set_target(gpu_target);
122 _mm_kernel.set_target(gpu_target);
123
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100124 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000125 {
126 mult_transpose1xW_width = 4;
127 mult_interleave4x4_height = 2;
128 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000129
giuros018b6b4a92018-12-18 19:01:33 +0000130 GEMMRHSMatrixInfo rhs_info;
131 rhs_info.n0 = 16 / b->info()->element_size();
132 rhs_info.k0 = 1;
133 rhs_info.h0 = mult_transpose1xW_width;
134 rhs_info.interleave = false;
135 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000136
giuros011c9efeb2019-01-11 14:04:43 +0000137 GEMMLHSMatrixInfo lhs_info;
138 lhs_info.m0 = 4;
139 lhs_info.k0 = 4;
140 lhs_info.v0 = mult_interleave4x4_height;
141 lhs_info.interleave = true;
142 lhs_info.transpose = true;
143
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100144 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000145
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100146 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
147
148 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000149 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100150
151 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100152 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000153 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100154 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100155
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000156 // Configure interleave kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100157 _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100158
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000159 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100160 ICLTensor *reshaped_rhs = &_tmp_b;
161 if(_weights_manager && _weights_manager->are_weights_managed(b))
162 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100163 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100164 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
165 }
166 else
167 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100168 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100169 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100170
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000171 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100172 _mm_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000173
174 CLScheduler::get().tune_kernel_static(_mm_kernel);
175
176 // Allocate intermediate tensors
177 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100178
179 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100180 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000181 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100182 }
183}
184
Manuel Bottini2b84be52020-04-08 10:15:51 +0100185void CLGEMM::configure_reshaped_v2(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
186 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000187{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000188 DataType data_type = a->info()->data_type();
189 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
190 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
191 const unsigned int n = b->info()->dimension(0);
192 const unsigned int k = a->info()->dimension(0);
193 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
194 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
195 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100196 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100197
198 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100199 kernel_info.m = m;
200 kernel_info.n = n;
201 kernel_info.k = k;
202 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
203 kernel_info.reinterpret_input_as_3d = false;
204 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100205 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000206
207 // Set the target for the kernels
208 _reshape_lhs_kernel.set_target(gpu_target);
209 _mm_kernel.set_target(gpu_target);
210
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100211 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
212
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000213 // Manage intermediate buffers
214 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100215
216 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000217 {
218 _memory_group.manage(&_tmp_b);
219 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100220
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000221 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
222
223 GEMMLHSMatrixInfo lhs_info{};
224 GEMMRHSMatrixInfo rhs_info{};
225
226 // Pick up the GEMM configuration
227 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
228 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
229
230 // Configure lhs_info and rhs_info
231 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
232
Manuel Bottini2b84be52020-04-08 10:15:51 +0100233 _reshape_lhs_kernel.configure(compile_context, a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100234
235 ICLTensor *reshaped_rhs = &_tmp_b;
236 if(_weights_manager && _weights_manager->are_weights_managed(b))
237 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100238 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100239 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
240 }
241 else
242 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100243 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100244 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000245
246 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100247 _mm_reshaped_kernel.configure(compile_context, &_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000248
249 // Allocate intermediate tensors
250 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100251
252 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000253 {
254 _tmp_b.allocator()->allocate();
255 }
256}
257
Manuel Bottini2b84be52020-04-08 10:15:51 +0100258void CLGEMM::configure_reshaped_only_rhs(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta,
259 const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000260{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000261 DataType data_type = a->info()->data_type();
262 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
263 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
264 const unsigned int n = b->info()->dimension(0);
265 const unsigned int k = a->info()->dimension(0);
266 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
267 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
268 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100269 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100270
271 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100272 kernel_info.m = m;
273 kernel_info.n = n;
274 kernel_info.k = k;
275 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
276 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
277 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100278 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000279
280 // Set the target for the kernels
281 _mm_kernel.set_target(gpu_target);
282
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100283 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
284
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000285 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100286 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000287 {
288 _memory_group.manage(&_tmp_b);
289 }
290
291 GEMMLHSMatrixInfo lhs_info{};
292 GEMMRHSMatrixInfo rhs_info{};
293
294 // Pick up the GEMM configuration
295 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
296 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
297
Gian Marco Iodicec6eaec32020-07-20 13:31:05 +0100298 unsigned int m_internal = m;
299 unsigned int b_internal = batch_size;
300 if(reinterpret_input_as_3d)
301 {
302 m_internal = a->info()->dimension(1);
303 b_internal = a->info()->dimension(2);
304 }
305
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000306 // Configure lhs_info and rhs_info
Gian Marco Iodicec6eaec32020-07-20 13:31:05 +0100307 std::tie(lhs_info, rhs_info) = gemm_config->configure(m_internal, n, k, b_internal, data_type);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000308
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100309 ICLTensor *reshaped_rhs = &_tmp_b;
310 if(_weights_manager && _weights_manager->are_weights_managed(b))
311 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100312 _reshape_rhs_kernel_managed.configure(compile_context, b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100313 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
314 }
315 else
316 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100317 _reshape_rhs_kernel.configure(compile_context, b, &_tmp_b, rhs_info);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100318 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000319
320 // Configure and tune matrix multiply kernel
Manuel Bottini2b84be52020-04-08 10:15:51 +0100321 _mm_reshaped_only_rhs_kernel.configure(compile_context, a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000322
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100323 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000324 {
325 _tmp_b.allocator()->allocate();
326 }
327}
328
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000329Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000330{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100331 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100332 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100333
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000334 // Get the GPU target
335 const GPUTarget gpu_target = CLScheduler::get().target();
336 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
337 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
338 const unsigned int n = b->dimension(0);
339 const unsigned int k = a->dimension(0);
340 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100341
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100342 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000343
344 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100345 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
346 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000347
348 return Status{};
349}
350
351Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
352{
353 ARM_COMPUTE_UNUSED(alpha);
354 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100355
356 TensorInfo tmp_a_info{};
357 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100358
359 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000360 const GPUTarget gpu_target = CLScheduler::get().target();
361 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000362 const unsigned int n = b->dimension(0);
363 const unsigned int k = a->dimension(0);
364 int mult_transpose1xW_width = 1;
365 int mult_interleave4x4_height = 1;
366 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100367
368 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
369 {
370 mult_transpose1xW_width = 4;
371 mult_interleave4x4_height = 2;
372 }
373
giuros018b6b4a92018-12-18 19:01:33 +0000374 GEMMRHSMatrixInfo rhs_info;
375 rhs_info.n0 = 16 / b->element_size();
376 rhs_info.k0 = 1;
377 rhs_info.h0 = mult_transpose1xW_width;
378 rhs_info.interleave = false;
379 rhs_info.transpose = false;
380
giuros011c9efeb2019-01-11 14:04:43 +0000381 GEMMLHSMatrixInfo lhs_info;
382 lhs_info.m0 = 4;
383 lhs_info.k0 = 4;
384 lhs_info.v0 = mult_interleave4x4_height;
385 lhs_info.interleave = true;
386 lhs_info.transpose = true;
387
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100388 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100389
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000390 // Validate interleave kernel
391 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
392 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000393
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000394 // Validate transpose kernel
395 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
396 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000397
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000398 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100399 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
400 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100401
Georgios Pinitas78c00902018-01-09 17:33:11 +0000402 return Status{};
403}
404
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000405Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000406{
407 ARM_COMPUTE_UNUSED(alpha);
408 ARM_COMPUTE_UNUSED(output);
409
410 TensorInfo tmp_a_info{};
411 TensorInfo tmp_b_info{};
412
413 // Get the GPU target
414 const GPUTarget gpu_target = CLScheduler::get().target();
415 DataType data_type = a->data_type();
416 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
417 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
418 const unsigned int n = b->dimension(0);
419 const unsigned int k = a->dimension(0);
420 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
421 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100422 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100423
424 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100425 kernel_info.m = m;
426 kernel_info.n = n;
427 kernel_info.k = k;
428 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
429 kernel_info.reinterpret_input_as_3d = false;
430 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100431 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000432
433 GEMMLHSMatrixInfo lhs_info;
434 GEMMRHSMatrixInfo rhs_info;
435
436 // Pick up the GEMM configuration
437 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
438 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
439
440 // Configure lhs_info and rhs_info
441 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
442
443 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
444 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
445
446 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
447 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
448
449 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100450 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000451
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000452 return Status{};
453}
454
455Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
456{
457 ARM_COMPUTE_UNUSED(alpha);
458 ARM_COMPUTE_UNUSED(output);
459
460 TensorInfo tmp_b_info{};
461
462 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100463 const GPUTarget gpu_target = CLScheduler::get().target();
464 const DataType data_type = a->data_type();
465 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
466 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
467 const unsigned int n = b->dimension(0);
468 const unsigned int k = a->dimension(0);
469 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
470 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
471 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100472
473 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100474 kernel_info.m = m;
475 kernel_info.n = n;
476 kernel_info.k = k;
477 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
478 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
479 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100480 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000481
482 GEMMLHSMatrixInfo lhs_info;
483 GEMMRHSMatrixInfo rhs_info;
484
485 // Pick up the GEMM configuration
486 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
487 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
488
489 // Configure lhs_info and rhs_info
490 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
491
492 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
493 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
494
495 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100496 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000497
498 return Status{};
499}
500
501void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
502{
Manuel Bottini2b84be52020-04-08 10:15:51 +0100503 configure(CLKernelLibrary::get().get_compile_context(), a, b, c, output, alpha, beta, gemm_info);
504}
505
506void CLGEMM::configure(const CLCompileContext &compile_context, const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
507{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000508 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
509
510 // Perform validation step
511 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
512
513 // Check if we need to reshape the matrix B only on the first run
514 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
515 _is_prepared = gemm_info.retain_internal_weights();
516 _original_b = b;
517
518 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000519 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
520 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
521 const unsigned int n = b->info()->dimension(0);
522 const unsigned int k = a->info()->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100523 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000524
525 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100526 _gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->info()->data_type(), _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000527
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100528 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100529
530 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000531
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000532 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000533 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000534 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000535 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100536 configure_native_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000537 break;
538 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000539 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000540 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100541 configure_reshaped_v1(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000542 break;
543 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000544 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000545 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100546 configure_reshaped_v2(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000547 break;
548 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000549 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000550 {
Manuel Bottini2b84be52020-04-08 10:15:51 +0100551 configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000552 break;
553 }
554 default:
555 {
556 ARM_COMPUTE_ERROR("GEMMType not supported");
557 }
558 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000559}
560
561Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
562{
563 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000564 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
565 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
566 const unsigned int n = b->dimension(0);
567 const unsigned int k = a->dimension(0);
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100568 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000569
570 // Select GEMMType
Gian Marco Iodice026d0452020-08-28 13:52:12 +0100571 CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, batch_size, a->data_type(), gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000572
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100573 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100574
575 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
576
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000577 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000578 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000579 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000580 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000581 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000582 break;
583 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000584 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000585 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100586 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000587 break;
588 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000589 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000590 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000591 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000592 break;
593 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000594 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000595 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100596 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000597 break;
598 }
599 default:
600 {
601 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
602 }
603 }
604
605 return Status{};
606}
607
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100608void CLGEMM::run()
609{
Georgios Pinitase0437672018-05-02 14:07:55 +0100610 prepare();
611
Georgios Pinitasda953f22019-04-02 17:27:03 +0100612 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100613
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100614 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000615 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000616 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000617 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000618 {
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100619 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000620 break;
621 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000622 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000623 {
624 // Run interleave kernel
625 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
626
627 if(!_reshape_b_only_on_first_run)
628 {
629 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100630 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
631 {
632 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
633 }
634 else
635 {
636 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
637 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000638 }
639
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100640 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000641 break;
642 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000643 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000644 {
645 // Run interleave kernel
646 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
647
648 if(!_reshape_b_only_on_first_run)
649 {
650 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100651 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
652 {
653 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
654 }
655 else
656 {
657 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
658 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000659 }
660
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100661 CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000662 break;
663 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000664 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000665 {
666 if(!_reshape_b_only_on_first_run)
667 {
668 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100669 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
670 {
671 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
672 }
673 else
674 {
675 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
676 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000677 }
678
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100679 CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000680 break;
681 }
682 default:
683 {
684 ARM_COMPUTE_ERROR("GEMMType not supported");
685 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000686 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100687}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100688
Georgios Pinitase0437672018-05-02 14:07:55 +0100689void CLGEMM::prepare()
690{
691 if(!_is_prepared)
692 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000693 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100694 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100695 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
696 {
697 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
698 }
699 else
700 {
701 // Run transpose kernel and mark original weights tensor as unused
702 _tmp_b.allocator()->allocate();
703 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
704 _original_b->mark_as_unused();
705 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100706 }
707 CLScheduler::get().queue().finish();
708 _is_prepared = true;
709 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100710}
giuros011c9efeb2019-01-11 14:04:43 +0000711} // namespace arm_compute