/*
 * Copyright (c) 2017-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/runtime/CL/functions/CLGEMM.h"
25
Gian Marco Iodice926afe12019-03-19 11:44:13 +000026#include "arm_compute/core/CL/ICLGEMMKernelConfiguration.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010027#include "arm_compute/core/CL/ICLTensor.h"
Gian Marco Iodice926afe12019-03-19 11:44:13 +000028#include "arm_compute/core/CL/gemm/reshaped/CLGEMMReshapedKernelConfiguration.h"
29#include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010030#include "arm_compute/core/Error.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010031#include "arm_compute/core/GPUTarget.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/Helpers.h"
Gian Marco Iodice7026b302019-06-26 17:18:11 +010033#include "arm_compute/core/KernelDescriptors.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/TensorInfo.h"
35#include "arm_compute/core/Types.h"
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +010036#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037#include "arm_compute/core/Validate.h"
Gian Marco Iodicee16c8902019-06-14 16:11:10 +010038#include "arm_compute/core/utils/helpers/float_ops.h"
Michalis Spyroub27e13a2019-09-27 11:04:27 +010039#include "arm_compute/core/utils/misc/Cast.h"
Gian Marco Iodice750641d2018-05-08 12:01:57 +010040#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010041#include "arm_compute/runtime/CL/CLScheduler.h"
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000042#include "arm_compute/runtime/CL/gemm/CLGEMMKernelSelection.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010043#include "arm_compute/runtime/ITensorAllocator.h"
44
giuros011c9efeb2019-01-11 14:04:43 +000045namespace arm_compute
46{
Gian Marco Iodice750641d2018-05-08 12:01:57 +010047using namespace arm_compute::misc::shape_calculator;
Gian Marco Iodice90313eb2019-01-16 15:40:25 +000048using namespace arm_compute::cl_gemm;
Michalis Spyroub27e13a2019-09-27 11:04:27 +010049using namespace arm_compute::utils::cast;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010050
Michalis Spyroub27e13a2019-09-27 11:04:27 +010051CLGEMM::CLGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010052 : _memory_group(std::move(memory_manager)),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010053 _weights_manager(weights_manager),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010054 _mm_kernel(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000055 _reshape_lhs_kernel(),
56 _reshape_rhs_kernel(),
Michalis Spyroub27e13a2019-09-27 11:04:27 +010057 _reshape_rhs_kernel_managed(),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000058 _mm_reshaped_kernel(),
Gian Marco Iodice926afe12019-03-19 11:44:13 +000059 _mm_reshaped_only_rhs_kernel(),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010060 _tmp_a(),
61 _tmp_b(),
62 _original_b(nullptr),
Georgios Pinitasebf6b8a2018-09-24 16:31:08 +010063 _reshape_b_only_on_first_run(false),
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000064 _is_prepared(false),
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000065 _gemm_kernel_type(CLGEMMKernelType::NATIVE_V1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010066{
67}
68
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000069CLGEMMKernelType CLGEMM::select_gemm_kernel(unsigned int m, unsigned int n, unsigned int k, DataType data_type, bool reshape_b_only_on_first_run)
Anthony Barbier6ff3b192017-09-04 18:44:23 +010070{
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000071 std::unique_ptr<ICLGEMMKernelSelection> gemm_kernel = CLGEMMKernelSelectionFactory::create(CLScheduler::get().target());
72 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_kernel.get());
Anthony Barbier6ff3b192017-09-04 18:44:23 +010073
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000074 CLGEMMKernelSelectionParams params;
75 params.m = m;
76 params.n = n;
77 params.k = k;
78 params.is_rhs_constant = reshape_b_only_on_first_run;
79 params.data_type = data_type;
Gian Marco Iodice05639f62019-09-24 12:05:06 +010080
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000081 return gemm_kernel->select_kernel(params);
Gian Marco Iodice926afe12019-03-19 11:44:13 +000082}
83
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +000084void CLGEMM::configure_native_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +000085{
86 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
87 const unsigned int n = b->info()->dimension(0);
88 const unsigned int k = a->info()->dimension(0);
89 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco36a0a462018-01-12 10:21:40 +000090
91 // Set the target for the kernels
Gian Marco36a0a462018-01-12 10:21:40 +000092 _mm_kernel.set_target(gpu_target);
93
Gian Marco Iodicef3622be2019-07-29 14:27:16 +010094 GEMMReshapeInfo reshape_info(m, n, k, 1, 1, gemm_info.depth_output_gemm3d(), gemm_info.reinterpret_input_as_3d(), gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +000095
96 // Configure and tune matrix multiply kernel
Gian Marco Iodicef3622be2019-07-29 14:27:16 +010097 _mm_kernel.configure(a, b, c, output, alpha, beta, false, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +000098
99 // Tune kernel statically
100 CLScheduler::get().tune_kernel_static(_mm_kernel);
101}
102
103void CLGEMM::configure_reshaped_v1(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
104{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000105 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
106 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
107 const unsigned int n = b->info()->dimension(0);
108 const unsigned int k = a->info()->dimension(0);
109 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000110 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000111 int mult_transpose1xW_width = 1;
112 int mult_interleave4x4_height = 1;
Gian Marco36a0a462018-01-12 10:21:40 +0000113
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000114 // Set the target for the kernels
115 _reshape_lhs_kernel.set_target(gpu_target);
116 _mm_kernel.set_target(gpu_target);
117
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100118 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
Gian Marco36a0a462018-01-12 10:21:40 +0000119 {
120 mult_transpose1xW_width = 4;
121 mult_interleave4x4_height = 2;
122 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000123
giuros018b6b4a92018-12-18 19:01:33 +0000124 GEMMRHSMatrixInfo rhs_info;
125 rhs_info.n0 = 16 / b->info()->element_size();
126 rhs_info.k0 = 1;
127 rhs_info.h0 = mult_transpose1xW_width;
128 rhs_info.interleave = false;
129 rhs_info.transpose = false;
Gian Marco36a0a462018-01-12 10:21:40 +0000130
giuros011c9efeb2019-01-11 14:04:43 +0000131 GEMMLHSMatrixInfo lhs_info;
132 lhs_info.m0 = 4;
133 lhs_info.k0 = 4;
134 lhs_info.v0 = mult_interleave4x4_height;
135 lhs_info.interleave = true;
136 lhs_info.transpose = true;
137
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100138 GEMMReshapeInfo reshape_info(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marcob5311a62017-12-13 12:48:03 +0000139
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100140 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
141
142 // Manage intermediate buffers
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000143 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100144
145 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100146 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000147 _memory_group.manage(&_tmp_b);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100148 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100149
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000150 // Configure interleave kernel
151 _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, reinterpret_input_as_3d);
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100152
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000153 // Configure transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100154 ICLTensor *reshaped_rhs = &_tmp_b;
155 if(_weights_manager && _weights_manager->are_weights_managed(b))
156 {
157 _reshape_rhs_kernel_managed.configure(b, rhs_info);
158 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
159 }
160 else
161 {
162 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
163 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100164
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000165 // Configure and tune matrix multiply kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100166 _mm_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, true, reshape_info, gemm_info.fp_mixed_precision(), gemm_info.activation_info());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000167
168 CLScheduler::get().tune_kernel_static(_mm_kernel);
169
170 // Allocate intermediate tensors
171 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100172
173 if(!_reshape_b_only_on_first_run && use_mm_b)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100174 {
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000175 _tmp_b.allocator()->allocate();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100176 }
177}
178
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000179void CLGEMM::configure_reshaped(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000180{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000181 DataType data_type = a->info()->data_type();
182 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
183 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
184 const unsigned int n = b->info()->dimension(0);
185 const unsigned int k = a->info()->dimension(0);
186 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
187 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
188 const GPUTarget gpu_target = CLScheduler::get().target();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100189 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100190
191 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100192 kernel_info.m = m;
193 kernel_info.n = n;
194 kernel_info.k = k;
195 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
196 kernel_info.reinterpret_input_as_3d = false;
197 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100198 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000199
200 // Set the target for the kernels
201 _reshape_lhs_kernel.set_target(gpu_target);
202 _mm_kernel.set_target(gpu_target);
203
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100204 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
205
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000206 // Manage intermediate buffers
207 _memory_group.manage(&_tmp_a);
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100208
209 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000210 {
211 _memory_group.manage(&_tmp_b);
212 }
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100213
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000214 // _tmp_a and _tmp_b will be auto configured in _interleave_kernel and in _transpose_kernel
215
216 GEMMLHSMatrixInfo lhs_info{};
217 GEMMRHSMatrixInfo rhs_info{};
218
219 // Pick up the GEMM configuration
220 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
221 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
222
223 // Configure lhs_info and rhs_info
224 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
225
226 _reshape_lhs_kernel.configure(a, &_tmp_a, lhs_info, gemm_info.reinterpret_input_as_3d());
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100227
228 ICLTensor *reshaped_rhs = &_tmp_b;
229 if(_weights_manager && _weights_manager->are_weights_managed(b))
230 {
231 _reshape_rhs_kernel_managed.configure(b, rhs_info);
232 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
233 }
234 else
235 {
236 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
237 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000238
239 // Configure and tune matrix multiply kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100240 _mm_reshaped_kernel.configure(&_tmp_a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000241
242 // Allocate intermediate tensors
243 _tmp_a.allocator()->allocate();
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100244
245 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000246 {
247 _tmp_b.allocator()->allocate();
248 }
249}
250
251void CLGEMM::configure_reshaped_only_rhs(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
252{
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000253 DataType data_type = a->info()->data_type();
254 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
255 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
256 const unsigned int n = b->info()->dimension(0);
257 const unsigned int k = a->info()->dimension(0);
258 const unsigned int batch_size = reinterpret_input_as_3d ? a->info()->dimension(3) : a->info()->dimension(2);
259 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
260 const GPUTarget gpu_target = CLScheduler::get().target();
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100261 bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100262
263 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100264 kernel_info.m = m;
265 kernel_info.n = n;
266 kernel_info.k = k;
267 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
268 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
269 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100270 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000271
272 // Set the target for the kernels
273 _mm_kernel.set_target(gpu_target);
274
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100275 const bool use_mm_b = (!_weights_manager || !_weights_manager->are_weights_managed(b));
276
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000277 // Manage intermediate buffers
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100278 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000279 {
280 _memory_group.manage(&_tmp_b);
281 }
282
283 GEMMLHSMatrixInfo lhs_info{};
284 GEMMRHSMatrixInfo rhs_info{};
285
286 // Pick up the GEMM configuration
287 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
288 ARM_COMPUTE_ERROR_ON_NULLPTR(gemm_config.get());
289
290 // Configure lhs_info and rhs_info
291 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
292
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100293 ICLTensor *reshaped_rhs = &_tmp_b;
294 if(_weights_manager && _weights_manager->are_weights_managed(b))
295 {
296 _reshape_rhs_kernel_managed.configure(b, rhs_info);
297 reshaped_rhs = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(b, &_reshape_rhs_kernel_managed));
298 }
299 else
300 {
301 _reshape_rhs_kernel.configure(b, &_tmp_b, rhs_info);
302 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000303
304 // Configure and tune matrix multiply kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100305 _mm_reshaped_only_rhs_kernel.configure(a, reshaped_rhs, c, output, alpha, beta, lhs_info, rhs_info, kernel_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000306
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100307 if(!_reshape_b_only_on_first_run && use_mm_b)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000308 {
309 _tmp_b.allocator()->allocate();
310 }
311}
312
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000313Status CLGEMM::validate_native_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Georgios Pinitas78c00902018-01-09 17:33:11 +0000314{
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100315 ARM_COMPUTE_UNUSED(alpha);
Gian Marco Iodice215b4ea2018-06-28 16:29:29 +0100316 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100317
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000318 // Get the GPU target
319 const GPUTarget gpu_target = CLScheduler::get().target();
320 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
321 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
322 const unsigned int n = b->dimension(0);
323 const unsigned int k = a->dimension(0);
324 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100325
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100326 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d, gemm_info.broadcast_bias());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000327
328 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100329 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(a, b, c, output, alpha, beta,
330 false, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000331
332 return Status{};
333}
334
335Status CLGEMM::validate_reshaped_v1(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
336{
337 ARM_COMPUTE_UNUSED(alpha);
338 ARM_COMPUTE_UNUSED(output);
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100339
340 TensorInfo tmp_a_info{};
341 TensorInfo tmp_b_info{};
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100342
343 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000344 const GPUTarget gpu_target = CLScheduler::get().target();
345 const unsigned int m = gemm_info.reinterpret_input_as_3d() ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000346 const unsigned int n = b->dimension(0);
347 const unsigned int k = a->dimension(0);
348 int mult_transpose1xW_width = 1;
349 int mult_interleave4x4_height = 1;
350 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100351
352 if(get_arch_from_target(gpu_target) == GPUTarget::BIFROST)
353 {
354 mult_transpose1xW_width = 4;
355 mult_interleave4x4_height = 2;
356 }
357
giuros018b6b4a92018-12-18 19:01:33 +0000358 GEMMRHSMatrixInfo rhs_info;
359 rhs_info.n0 = 16 / b->element_size();
360 rhs_info.k0 = 1;
361 rhs_info.h0 = mult_transpose1xW_width;
362 rhs_info.interleave = false;
363 rhs_info.transpose = false;
364
giuros011c9efeb2019-01-11 14:04:43 +0000365 GEMMLHSMatrixInfo lhs_info;
366 lhs_info.m0 = 4;
367 lhs_info.k0 = 4;
368 lhs_info.v0 = mult_interleave4x4_height;
369 lhs_info.interleave = true;
370 lhs_info.transpose = true;
371
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100372 const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, depth_output_gemm3d, false, gemm_info.broadcast_bias());
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100373
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000374 // Validate interleave kernel
375 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
376 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000377
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000378 // Validate transpose kernel
379 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
380 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
Michele Di Giorgioebc3a902018-11-16 16:04:25 +0000381
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000382 // Validate matrix multiply
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100383 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta,
384 true, reshape_info, gpu_target, gemm_info.fp_mixed_precision(), gemm_info.activation_info()));
Gian Marco Iodice750641d2018-05-08 12:01:57 +0100385
Georgios Pinitas78c00902018-01-09 17:33:11 +0000386 return Status{};
387}
388
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000389Status CLGEMM::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000390{
391 ARM_COMPUTE_UNUSED(alpha);
392 ARM_COMPUTE_UNUSED(output);
393
394 TensorInfo tmp_a_info{};
395 TensorInfo tmp_b_info{};
396
397 // Get the GPU target
398 const GPUTarget gpu_target = CLScheduler::get().target();
399 DataType data_type = a->data_type();
400 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
401 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
402 const unsigned int n = b->dimension(0);
403 const unsigned int k = a->dimension(0);
404 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
405 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100406 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100407
408 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100409 kernel_info.m = m;
410 kernel_info.n = n;
411 kernel_info.k = k;
412 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
413 kernel_info.reinterpret_input_as_3d = false;
414 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100415 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000416
417 GEMMLHSMatrixInfo lhs_info;
418 GEMMRHSMatrixInfo rhs_info;
419
420 // Pick up the GEMM configuration
421 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedKernelConfigurationFactory::create(gpu_target);
422 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
423
424 // Configure lhs_info and rhs_info
425 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
426
427 auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_lhs_reshaped_shape(*a, lhs_info, gemm_info.reinterpret_input_as_3d())));
428 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeLHSMatrixKernel::validate(a, &tmp_a_info, lhs_info, gemm_info.reinterpret_input_as_3d()));
429
430 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
431 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
432
433 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100434 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedKernel::validate(&tmp_a_info, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000435
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000436 return Status{};
437}
438
439Status CLGEMM::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
440{
441 ARM_COMPUTE_UNUSED(alpha);
442 ARM_COMPUTE_UNUSED(output);
443
444 TensorInfo tmp_b_info{};
445
446 // Get the GPU target
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100447 const GPUTarget gpu_target = CLScheduler::get().target();
448 const DataType data_type = a->data_type();
449 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
450 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
451 const unsigned int n = b->dimension(0);
452 const unsigned int k = a->dimension(0);
453 const unsigned int batch_size = reinterpret_input_as_3d ? a->dimension(3) : a->dimension(2);
454 const int depth_output_gemm3d = gemm_info.depth_output_gemm3d();
455 const bool broadcast_bias = gemm_info.broadcast_bias();
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100456
457 GEMMKernelInfo kernel_info;
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100458 kernel_info.m = m;
459 kernel_info.n = n;
460 kernel_info.k = k;
461 kernel_info.depth_output_gemm3d = depth_output_gemm3d;
462 kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d;
463 kernel_info.broadcast_bias = broadcast_bias;
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100464 kernel_info.activation_info = gemm_info.activation_info();
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000465
466 GEMMLHSMatrixInfo lhs_info;
467 GEMMRHSMatrixInfo rhs_info;
468
469 // Pick up the GEMM configuration
470 std::unique_ptr<ICLGEMMKernelConfiguration> gemm_config = CLGEMMReshapedOnlyRHSKernelConfigurationFactory::create(gpu_target);
471 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(gemm_config.get());
472
473 // Configure lhs_info and rhs_info
474 std::tie(lhs_info, rhs_info) = gemm_config->configure(m, n, k, batch_size, data_type);
475
476 auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info)));
477 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMReshapeRHSMatrixKernel::validate(b, &tmp_b_info, rhs_info));
478
479 // Validate matrix multiply
Gian Marco Iodice7026b302019-06-26 17:18:11 +0100480 ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMMatrixMultiplyReshapedOnlyRHSKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000481
482 return Status{};
483}
484
485void CLGEMM::configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info)
486{
487 ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
488
489 // Perform validation step
490 ARM_COMPUTE_ERROR_THROW_ON(validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, output->info(), alpha, beta, gemm_info));
491
492 // Check if we need to reshape the matrix B only on the first run
493 _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
494 _is_prepared = gemm_info.retain_internal_weights();
495 _original_b = b;
496
497 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000498 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
499 const unsigned int m = reinterpret_input_as_3d ? (a->info()->dimension(1) * a->info()->dimension(2)) : a->info()->dimension(1);
500 const unsigned int n = b->info()->dimension(0);
501 const unsigned int k = a->info()->dimension(0);
502
503 // Select GEMMType
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000504 _gemm_kernel_type = select_gemm_kernel(m, n, k, a->info()->data_type(), _reshape_b_only_on_first_run);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000505
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100506 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100507
508 const ICLTensor *c_to_use = fuse_add_c ? c : nullptr;
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000509
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000510 switch(_gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000511 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000512 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000513 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000514 configure_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000515 break;
516 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000517 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000518 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100519 configure_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000520 break;
521 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000522 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000523 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000524 configure_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000525 break;
526 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000527 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000528 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100529 configure_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000530 break;
531 }
532 default:
533 {
534 ARM_COMPUTE_ERROR("GEMMType not supported");
535 }
536 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000537}
538
539Status CLGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info)
540{
541 // Get the GPU target
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000542 bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
543 const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
544 const unsigned int n = b->dimension(0);
545 const unsigned int k = a->dimension(0);
546
547 // Select GEMMType
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000548 CLGEMMKernelType gemm_kernel_type = select_gemm_kernel(m, n, k, a->data_type(), gemm_info.reshape_b_only_on_first_run());
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000549
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100550 const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100551
552 const ITensorInfo *c_to_use = fuse_add_c ? c : nullptr;
553
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000554 switch(gemm_kernel_type)
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000555 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000556 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000557 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000558 ARM_COMPUTE_RETURN_ON_ERROR(validate_native_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000559 break;
560 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000561 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000562 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100563 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_v1(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000564 break;
565 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000566 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000567 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000568 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000569 break;
570 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000571 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000572 {
Gian Marco Iodicee16c8902019-06-14 16:11:10 +0100573 ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info));
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000574 break;
575 }
576 default:
577 {
578 ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported");
579 }
580 }
581
582 return Status{};
583}
584
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100585void CLGEMM::run()
586{
Georgios Pinitase0437672018-05-02 14:07:55 +0100587 prepare();
588
Georgios Pinitasda953f22019-04-02 17:27:03 +0100589 MemoryGroupResourceScope scope_mg(_memory_group);
Georgios Pinitas8a94e7c2017-09-15 19:06:47 +0100590
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100591 // Run matrix multiply kernel
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000592 switch(_gemm_kernel_type)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000593 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000594 case CLGEMMKernelType::NATIVE_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000595 {
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100596 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000597 break;
598 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000599 case CLGEMMKernelType::RESHAPED_V1:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000600 {
601 // Run interleave kernel
602 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
603
604 if(!_reshape_b_only_on_first_run)
605 {
606 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100607 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
608 {
609 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
610 }
611 else
612 {
613 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
614 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000615 }
616
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100617 CLScheduler::get().enqueue(_mm_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000618 break;
619 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000620 case CLGEMMKernelType::RESHAPED:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000621 {
622 // Run interleave kernel
623 CLScheduler::get().enqueue(_reshape_lhs_kernel, false);
624
625 if(!_reshape_b_only_on_first_run)
626 {
627 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100628 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
629 {
630 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
631 }
632 else
633 {
634 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
635 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000636 }
637
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100638 CLScheduler::get().enqueue(_mm_reshaped_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000639 break;
640 }
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000641 case CLGEMMKernelType::RESHAPED_ONLY_RHS:
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000642 {
643 if(!_reshape_b_only_on_first_run)
644 {
645 // Run transpose kernel
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100646 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
647 {
648 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
649 }
650 else
651 {
652 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
653 }
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000654 }
655
Gian Marco Iodicef3622be2019-07-29 14:27:16 +0100656 CLScheduler::get().enqueue(_mm_reshaped_only_rhs_kernel, true);
Gian Marco Iodice926afe12019-03-19 11:44:13 +0000657 break;
658 }
659 default:
660 {
661 ARM_COMPUTE_ERROR("GEMMType not supported");
662 }
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000663 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100664}
Georgios Pinitas82b51482018-04-24 15:14:12 +0100665
Georgios Pinitase0437672018-05-02 14:07:55 +0100666void CLGEMM::prepare()
667{
668 if(!_is_prepared)
669 {
Gian Marco Iodice5a4fe192020-03-16 12:22:37 +0000670 if(_gemm_kernel_type != CLGEMMKernelType::NATIVE_V1 && _reshape_b_only_on_first_run)
Georgios Pinitase0437672018-05-02 14:07:55 +0100671 {
Michalis Spyroub27e13a2019-09-27 11:04:27 +0100672 if(_weights_manager && _weights_manager->are_weights_managed(_original_b))
673 {
674 _weights_manager->run(_original_b, &_reshape_rhs_kernel_managed);
675 }
676 else
677 {
678 // Run transpose kernel and mark original weights tensor as unused
679 _tmp_b.allocator()->allocate();
680 CLScheduler::get().enqueue(_reshape_rhs_kernel, false);
681 _original_b->mark_as_unused();
682 }
Georgios Pinitase0437672018-05-02 14:07:55 +0100683 }
684 CLScheduler::get().queue().finish();
685 _is_prepared = true;
686 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100687}
giuros011c9efeb2019-01-11 14:04:43 +0000688} // namespace arm_compute