blob: 01102b3d602da5d8bb27d71c9bf2be4b9ed59040 [file] [log] [blame]
Gian Marco Iodice352c07d2023-05-03 12:21:38 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "src/runtime/heuristics/matmul_native/ClMatMulNativeDefaultConfigValhall.h"
25
26#include "arm_compute/core/CL/CLHelpers.h"
27#include "arm_compute/core/CL/CLKernelLibrary.h"
28#include "arm_compute/core/GPUTarget.h"
29#include "arm_compute/core/KernelDescriptors.h"
30#include "arm_compute/core/TensorInfo.h"
31#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h"
32#include <utility>
33
34#include "src/runtime/heuristics/matmul_native/ClMatMulNativeHelpers.h"
35
36namespace arm_compute
37{
38namespace cl_matmul
39{
40ClMatMulNativeDefaultConfigValhall::ClMatMulNativeDefaultConfigValhall(GPUTarget gpu)
41 : IClMatMulNativeKernelConfig(gpu)
42{
43}
44
45MatMulKernelInfo ClMatMulNativeDefaultConfigValhall::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, const MatMulInfo &info)
46{
47 using ConfigurationFunctionExecutorPtr = MatMulKernelInfo (ClMatMulNativeDefaultConfigValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool rhs_lock_padding, const MatMulInfo & info);
48
49 ClMatMulNativeConfigArray<ConfigurationFunctionExecutorPtr> configs_G710(&ClMatMulNativeDefaultConfigValhall::configure_G710_f32,
50 &ClMatMulNativeDefaultConfigValhall::configure_G710_f16,
51 &ClMatMulNativeDefaultConfigValhall::configure_G710_u8);
52
53 ConfigurationFunctionExecutorPtr func = nullptr;
54 switch(_target)
55 {
56 case GPUTarget::G710:
57 default:
58 func = configs_G710.get_function(lhs->data_type());
59 break;
60 }
61
62 const bool adj_lhs = info.adj_lhs();
63 const bool adj_rhs = info.adj_rhs();
64
65 TensorShape lhs_shape = lhs->tensor_shape();
66 TensorShape rhs_shape = rhs->tensor_shape();
67
68 const bool is_batched = lhs_shape.num_dimensions() > 2;
69
70 if(is_batched == true)
71 {
72 lhs_shape.collapse_from(2);
73 }
74
75 const unsigned int m = adj_lhs ? lhs_shape.x() : lhs_shape.y();
76 const unsigned int n = adj_rhs ? rhs_shape.y() : rhs_shape.x();
77 const unsigned int k = adj_lhs ? lhs_shape.y() : lhs_shape.x();
78 const unsigned int b = lhs_shape.z();
79
80 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not supported for matmul native");
81 return (this->*func)(m, n, k, b, rhs->lock_paddings(), info);
82}
83
84MatMulKernelInfo ClMatMulNativeDefaultConfigValhall::configure_G710_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool rhs_lock_padding, const MatMulInfo &info)
85{
86 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_nt =
87 {
88 { 3136, 64, 64, 36, 4, 4, 16, 1 },
89 { 4096, 48, 32, 36, 4, 4, 4, 1 },
90 { 688, 92, 68, 32, 2, 8, 4, 1 },
91 { 24, 464, 412, 24, 2, 8, 4, 1 },
92 { 112, 184, 144, 28, 4, 4, 16, 1 },
93 { 5776, 64, 32, 36, 2, 4, 16, 1 },
94 { 1568, 64, 40, 36, 2, 8, 8, 1 },
95 { 2920, 64, 64, 24, 4, 4, 16, 1 }
96 };
97
98 const MatMulNativeConfigsMatrix configs_mnkb_fallback_nt_nt =
99 {
100 { 3136, 64, 64, 36, 4, 4, 8, 0 },
101 { 4096, 48, 32, 36, 4, 4, 8, 0 },
102 { 688, 92, 68, 32, 5, 4, 4, 0 },
103 { 24, 464, 412, 24, 6, 2, 8, 0 },
104 { 112, 184, 144, 28, 6, 4, 4, 0 },
105 { 5776, 64, 32, 36, 5, 4, 4, 0 },
106 { 1568, 64, 40, 36, 4, 4, 8, 0 },
107 { 2920, 64, 64, 24, 4, 4, 8, 0 }
108 };
109
110 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_t =
111 {
112 { 3136, 64, 64, 36, 4, 4, 4, 1 },
113 { 4096, 48, 32, 36, 2, 2, 16, 1 },
114 { 688, 92, 68, 32, 4, 4, 4, 1 },
115 { 24, 464, 412, 24, 6, 2, 8, 1 },
116 { 112, 184, 144, 28, 4, 2, 16, 1 },
117 { 5776, 64, 32, 36, 4, 4, 4, 1 },
118 { 1568, 64, 40, 36, 4, 4, 8, 1 },
119 { 2920, 64, 64, 24, 4, 4, 4, 1 }
120 };
121
122 const MatMulNativeConfigsMatrix configs_mnkb_fallback_nt_t =
123 {
124 { 3136, 64, 64, 36, 5, 4, 4, 0 },
125 { 4096, 48, 32, 36, 5, 4, 4, 0 },
126 { 688, 92, 68, 32, 5, 4, 4, 0 },
127 { 24, 464, 412, 24, 6, 2, 4, 0 },
128 { 112, 184, 144, 28, 5, 4, 4, 0 },
129 { 5776, 64, 32, 36, 5, 4, 4, 0 },
130 { 1568, 64, 40, 36, 5, 4, 4, 0 },
131 { 2920, 64, 64, 24, 6, 2, 4, 0 }
132 };
133
134 const MatMulNativeConfigsMatrix configs_mnkb_best_t_nt =
135 {
136 { 3136, 64, 64, 36, 4, 4, 16, 1 },
137 { 4096, 48, 32, 36, 4, 4, 4, 1 },
138 { 688, 92, 68, 32, 2, 8, 4, 1 },
139 { 24, 464, 412, 24, 2, 8, 4, 1 },
140 { 112, 184, 144, 28, 4, 4, 16, 1 },
141 { 5776, 64, 32, 36, 2, 8, 8, 1 },
142 { 1568, 64, 40, 36, 4, 4, 8, 1 },
143 { 2920, 64, 64, 24, 4, 4, 16, 1 }
144 };
145
146 const MatMulNativeConfigsMatrix configs_mnkb_fallback_t_nt =
147 {
148 { 3136, 64, 64, 36, 4, 4, 4, 0 },
149 { 4096, 48, 32, 36, 4, 4, 4, 0 },
150 { 688, 92, 68, 32, 4, 4, 4, 0 },
151 { 24, 464, 412, 24, 4, 4, 4, 0 },
152 { 112, 184, 144, 28, 4, 4, 4, 0 },
153 { 5776, 64, 32, 36, 4, 4, 8, 0 },
154 { 1568, 64, 40, 36, 4, 4, 4, 0 },
155 { 2920, 64, 64, 24, 4, 4, 4, 0 }
156 };
157
158 const MatMulNativeConfigsMatrix configs_mnkb_best_t_t =
159 {
160 { 3136, 64, 64, 36, 4, 4, 4, 1 },
161 { 4096, 48, 32, 36, 4, 4, 4, 1 },
162 { 688, 92, 68, 32, 4, 4, 4, 1 },
163 { 24, 464, 412, 24, 2, 2, 16, 1 },
164 { 112, 184, 144, 28, 4, 4, 4, 1 },
165 { 5776, 64, 32, 36, 4, 4, 4, 1 },
166 { 1568, 64, 40, 36, 4, 4, 4, 1 },
167 { 2920, 64, 64, 24, 4, 4, 4, 1 }
168 };
169
170 const MatMulNativeConfigsMatrix configs_mnkb_fallback_t_t =
171 {
172 { 3136, 64, 64, 36, 4, 4, 4, 0 },
173 { 4096, 48, 32, 36, 4, 4, 4, 0 },
174 { 688, 92, 68, 32, 4, 4, 4, 0 },
175 { 24, 464, 412, 24, 4, 2, 8, 0 },
176 { 112, 184, 144, 28, 4, 4, 4, 0 },
177 { 5776, 64, 32, 36, 4, 4, 4, 0 },
178 { 1568, 64, 40, 36, 4, 4, 4, 0 },
179 { 2920, 64, 64, 24, 4, 4, 4, 0 }
180 };
181
182 const bool adj_lhs = info.adj_lhs();
183 const bool adj_rhs = info.adj_rhs();
184
185 const MatMulNativeConfigsMatrix *configs_best_to_use = nullptr;
186 const MatMulNativeConfigsMatrix *configs_fallback_to_use = nullptr;
187
188 if((adj_lhs == false) && (adj_rhs == false))
189 {
190 configs_best_to_use = &configs_mnkb_best_nt_nt;
191 configs_fallback_to_use = &configs_mnkb_fallback_nt_nt;
192 }
193 else if((adj_lhs == false) && (adj_rhs == true))
194 {
195 configs_best_to_use = &configs_mnkb_best_nt_t;
196 configs_fallback_to_use = &configs_mnkb_fallback_nt_t;
197 }
198 else if((adj_lhs == true) && (adj_rhs == false))
199 {
200 configs_best_to_use = &configs_mnkb_best_t_nt;
201 configs_fallback_to_use = &configs_mnkb_fallback_t_nt;
202 }
203 else
204 {
205 configs_best_to_use = &configs_mnkb_best_t_t;
206 configs_fallback_to_use = &configs_mnkb_fallback_t_t;
207 }
208
209 MatMulKernelInfo desc0 = find_info(*configs_best_to_use, adj_lhs, adj_rhs, m, n, k, b);
210 MatMulKernelInfo desc1 = find_info(*configs_fallback_to_use, adj_lhs, adj_rhs, m, n, k, b);
211
212 return select_info(desc0,
213 desc1,
214 m, n, k, b, DataType::F32, rhs_lock_padding);
215}
216
217MatMulKernelInfo ClMatMulNativeDefaultConfigValhall::configure_G710_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool rhs_lock_padding, const MatMulInfo &info)
218{
219 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_nt =
220 {
221 { 3136, 64, 64, 36, 4, 4, 16, 1 },
222 { 4096, 48, 32, 36, 4, 4, 8, 1 },
223 { 688, 92, 68, 32, 4, 4, 16, 1 },
224 { 24, 464, 412, 24, 4, 4, 4, 1 },
225 { 112, 184, 144, 28, 4, 4, 16, 1 },
226 { 5776, 64, 32, 36, 4, 4, 8, 1 },
227 { 1568, 64, 40, 36, 4, 4, 8, 1 },
228 { 2920, 64, 64, 24, 4, 4, 16, 1 }
229 };
230
231 const MatMulNativeConfigsMatrix configs_mnkb_fallback_nt_nt =
232 {
233 { 3136, 64, 64, 36, 6, 4, 8, 0 },
234 { 4096, 48, 32, 36, 6, 4, 8, 0 },
235 { 688, 92, 68, 32, 6, 4, 8, 0 },
236 { 24, 464, 412, 24, 4, 4, 8, 0 },
237 { 112, 184, 144, 28, 6, 4, 8, 0 },
238 { 5776, 64, 32, 36, 6, 4, 8, 0 },
239 { 1568, 64, 40, 36, 6, 4, 8, 0 },
240 { 2920, 64, 64, 24, 6, 4, 8, 0 }
241 };
242
243 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_t =
244 {
245 { 3136, 64, 64, 36, 6, 4, 8, 1 },
246 { 4096, 48, 32, 36, 6, 4, 8, 1 },
247 { 688, 92, 68, 32, 4, 4, 4, 1 },
248 { 24, 464, 412, 24, 6, 2, 4, 1 },
249 { 112, 184, 144, 28, 4, 2, 16, 1 },
250 { 5776, 64, 32, 36, 6, 4, 8, 1 },
251 { 1568, 64, 40, 36, 6, 4, 8, 1 },
252 { 2920, 64, 64, 24, 6, 4, 8, 1 }
253 };
254
255 const MatMulNativeConfigsMatrix configs_mnkb_fallback_nt_t =
256 {
257 { 3136, 64, 64, 36, 6, 2, 16, 0 },
258 { 4096, 48, 32, 36, 5, 4, 8, 0 },
259 { 688, 92, 68, 32, 6, 2, 16, 0 },
260 { 24, 464, 412, 24, 6, 2, 16, 0 },
261 { 112, 184, 144, 28, 6, 2, 16, 0 },
262 { 5776, 64, 32, 36, 5, 4, 8, 0 },
263 { 1568, 64, 40, 36, 5, 4, 8, 0 },
264 { 2920, 64, 64, 24, 6, 2, 16, 0 }
265 };
266
267 const MatMulNativeConfigsMatrix configs_mnkb_best_t_nt =
268 {
269 { 3136, 64, 64, 36, 4, 4, 16, 1 },
270 { 4096, 48, 32, 36, 4, 4, 4, 1 },
271 { 688, 92, 68, 32, 4, 4, 4, 1 },
272 { 24, 464, 412, 24, 4, 4, 4, 1 },
273 { 112, 184, 144, 28, 4, 4, 4, 1 },
274 { 5776, 64, 32, 36, 4, 4, 4, 1 },
275 { 1568, 64, 40, 36, 4, 4, 4, 1 },
276 { 2920, 64, 64, 24, 4, 4, 4, 1 }
277 };
278
279 const MatMulNativeConfigsMatrix configs_mnkb_fallback_t_nt =
280 {
281 { 3136, 64, 64, 36, 4, 4, 4, 0 },
282 { 4096, 48, 32, 36, 4, 4, 4, 0 },
283 { 688, 92, 68, 32, 4, 4, 4, 0 },
284 { 24, 464, 412, 24, 4, 4, 4, 0 },
285 { 112, 184, 144, 28, 4, 4, 4, 0 },
286 { 5776, 64, 32, 36, 4, 4, 4, 0 },
287 { 1568, 64, 40, 36, 4, 4, 4, 0 },
288 { 2920, 64, 64, 24, 4, 4, 4, 0 }
289 };
290
291 const MatMulNativeConfigsMatrix configs_mnkb_best_t_t =
292 {
293 { 3136, 64, 64, 36, 4, 4, 16, 1 },
294 { 4096, 48, 32, 36, 4, 4, 8, 1 },
295 { 688, 92, 68, 32, 4, 4, 4, 1 },
296 { 24, 464, 412, 24, 4, 2, 8, 1 },
297 { 112, 184, 144, 28, 4, 2, 16, 1 },
298 { 5776, 64, 32, 36, 4, 4, 16, 1 },
299 { 1568, 64, 40, 36, 4, 4, 8, 1 },
300 { 2920, 64, 64, 24, 4, 4, 16, 1 }
301 };
302
303 const MatMulNativeConfigsMatrix configs_mnkb_fallback_t_t =
304 {
305 { 3136, 64, 64, 36, 4, 4, 8, 0 },
306 { 4096, 48, 32, 36, 4, 4, 8, 0 },
307 { 688, 92, 68, 32, 4, 4, 8, 0 },
308 { 24, 464, 412, 24, 4, 4, 8, 0 },
309 { 112, 184, 144, 28, 4, 4, 8, 0 },
310 { 5776, 64, 32, 36, 4, 4, 8, 0 },
311 { 1568, 64, 40, 36, 4, 4, 8, 0 },
312 { 2920, 64, 64, 24, 4, 4, 8, 0 }
313 };
314
315 const bool adj_lhs = info.adj_lhs();
316 const bool adj_rhs = info.adj_rhs();
317
318 const MatMulNativeConfigsMatrix *configs_best_to_use = nullptr;
319 const MatMulNativeConfigsMatrix *configs_fallback_to_use = nullptr;
320
321 if((adj_lhs == false) && (adj_rhs == false))
322 {
323 configs_best_to_use = &configs_mnkb_best_nt_nt;
324 configs_fallback_to_use = &configs_mnkb_fallback_nt_nt;
325 }
326 else if((adj_lhs == false) && (adj_rhs == true))
327 {
328 configs_best_to_use = &configs_mnkb_best_nt_t;
329 configs_fallback_to_use = &configs_mnkb_fallback_nt_t;
330 }
331 else if((adj_lhs == true) && (adj_rhs == false))
332 {
333 configs_best_to_use = &configs_mnkb_best_t_nt;
334 configs_fallback_to_use = &configs_mnkb_fallback_t_nt;
335 }
336 else
337 {
338 configs_best_to_use = &configs_mnkb_best_t_t;
339 configs_fallback_to_use = &configs_mnkb_fallback_t_t;
340 }
341
342 MatMulKernelInfo desc0 = find_info(*configs_best_to_use, adj_lhs, adj_rhs, m, n, k, b);
343 MatMulKernelInfo desc1 = find_info(*configs_fallback_to_use, adj_lhs, adj_rhs, m, n, k, b);
344
345 return select_info(desc0,
346 desc1,
347 m, n, k, b, DataType::F16, rhs_lock_padding);
348}
349
350MatMulKernelInfo ClMatMulNativeDefaultConfigValhall::configure_G710_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool rhs_lock_padding, const MatMulInfo &info)
351{
352 ARM_COMPUTE_UNUSED(rhs_lock_padding);
353
354 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_nt =
355 {
356 { 3136, 64, 64, 36, 6, 4, 4, 0 },
357 { 4096, 48, 32, 36, 6, 4, 4, 0 },
358 { 688, 92, 68, 32, 2, 8, 4, 0 },
359 { 24, 464, 412, 24, 4, 4, 4, 0 },
360 { 112, 184, 144, 28, 6, 4, 4, 0 },
361 { 5776, 64, 32, 36, 6, 4, 4, 0 },
362 { 1568, 64, 40, 36, 6, 4, 4, 0 },
363 { 2920, 64, 64, 24, 5, 4, 4, 0 }
364 };
365
366 const MatMulNativeConfigsMatrix configs_mnkb_best_nt_t =
367 {
368 { 3136, 64, 64, 36, 4, 4, 16, 0 },
369 { 4096, 48, 32, 36, 4, 4, 16, 0 },
370 { 688, 92, 68, 32, 4, 4, 16, 0 },
371 { 24, 464, 412, 24, 6, 2, 16, 0 },
372 { 112, 184, 144, 28, 4, 4, 16, 0 },
373 { 5776, 64, 32, 36, 4, 4, 16, 0 },
374 { 1568, 64, 40, 36, 6, 4, 4, 0 },
375 { 2920, 64, 64, 24, 4, 4, 16, 0 }
376 };
377
378 const MatMulNativeConfigsMatrix configs_mnkb_best_t_nt =
379 {
380 { 3136, 64, 64, 36, 4, 4, 8, 0 },
381 { 4096, 48, 32, 36, 4, 4, 8, 0 },
382 { 688, 92, 68, 32, 4, 4, 4, 0 },
383 { 24, 464, 412, 24, 4, 4, 4, 0 },
384 { 112, 184, 144, 28, 4, 4, 8, 0 },
385 { 5776, 64, 32, 36, 4, 4, 8, 0 },
386 { 1568, 64, 40, 36, 4, 4, 8, 0 },
387 { 2920, 64, 64, 24, 4, 4, 8, 0 }
388 };
389
390 const MatMulNativeConfigsMatrix configs_mnkb_best_t_t =
391 {
392 { 3136, 64, 64, 36, 4, 2, 16, 0 },
393 { 4096, 48, 32, 36, 4, 4, 4, 0 },
394 { 688, 92, 68, 32, 4, 4, 8, 0 },
395 { 24, 464, 412, 24, 4, 2, 16, 0 },
396 { 112, 184, 144, 28, 4, 2, 16, 0 },
397 { 5776, 64, 32, 36, 4, 4, 4, 0 },
398 { 1568, 64, 40, 36, 4, 4, 8, 0 },
399 { 2920, 64, 64, 24, 4, 2, 16, 0 }
400 };
401
402 const bool adj_lhs = info.adj_lhs();
403 const bool adj_rhs = info.adj_rhs();
404
405 if((adj_lhs == false) && (adj_rhs == false))
406 {
407 return find_info(configs_mnkb_best_nt_nt, adj_lhs, adj_rhs, m, n, k, b);
408 }
409 else if((adj_lhs == false) && (adj_rhs == true))
410 {
411 return find_info(configs_mnkb_best_nt_t, adj_lhs, adj_rhs, m, n, k, b);
412 }
413 else if((adj_lhs == true) && (adj_rhs == false))
414 {
415 return find_info(configs_mnkb_best_t_nt, adj_lhs, adj_rhs, m, n, k, b);
416 }
417 else
418 {
419 return find_info(configs_mnkb_best_t_t, adj_lhs, adj_rhs, m, n, k, b);
420 }
421}
} // namespace cl_matmul
423} // namespace arm_compute