/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"
#include "tile_helpers.h"

#if defined(MAT_MUL_NATIVE_NT_NT)
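// Example build options (illustrative only; the define values below are assumptions,
// not taken from the original source). With the compile-time options documented below,
// an F32 matmul with K = 64 and no leftover output rows/columns could be compiled with:
//   -DMAT_MUL_NATIVE_NT_NT -DDATA_TYPE=float -DM0=4 -DN0=4 -DK0=4 -DK=64
//   -DPARTIAL_STORE_M0=0 -DPARTIAL_STORE_N0=0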
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 > 0
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_nt_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
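    // x, y: starting output column/row of the M0 x N0 tile computed by this work-item;
    // z: index of the matrix in the batch.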
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, K0, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, b, acc);

        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += K0 * rhs_stride_y;
    }

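    // Leftover loop: with the documented example K = 6 and K0 = 4, the main loop above
    // performs a single iteration covering k = 0..3, and the loop below consumes the
    // remaining k = 4 and k = 5 one step at a time.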
#if K % K0 != 0
    /* Leftover Loop */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, 1, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, b, acc);

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += 1 * rhs_stride_y;
    }
#endif // K % K0 != 0

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

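    // Row indices for the indirect store below. Illustration (not from the original
    // source): with M0 = 4 and PARTIAL_STORE_M0 = 3 on a boundary block (y_cond true),
    // indirect_buffer holds {0, 1, 2, 2}, so the out-of-range tile row is redirected
    // onto the last valid output row instead of being stored out of bounds.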
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_NT_NT)

#if defined(MAT_MUL_NATIVE_NT_T)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS transposed - buffer only
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_T)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 > 0
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_nt_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
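    // RHS is transposed, so the x (N) coordinate advances along rows of the stored
    // tensor, hence the rhs_stride_y term below.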
    rhs_offset_first_element_in_bytes += x * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // This part is written to decrease the number of loop unrollings caused
        // by T_MMUL. The NT/NT version is partly vectorized and requires fewer
        // loop unrollings, so the code behaves as expected. Although this is not
        // a performant solution for the specified architecture, it is necessary
        // to overcome some limitations.
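        // Illustration (not from the original source): with N0 = 2 and K0 = 3, the
        // 2x3 tile b is rearranged into the 3x2 tile bt (bt[j].s[i] = b[i].s[j]) so
        // that the NT/NT variant of T_MMUL can be used instead of NT/T.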
        TILE(DATA_TYPE, K0, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                bt[j].s[i] = b[i].s[j];
            })
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if K % K0 != 0
    /* Leftover Loop */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // See the main loop for the explanation of this part
        TILE(DATA_TYPE, 1, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            bt[0].s[i] = b[i].s[0];
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // K % K0 != 0

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_NT_T)

#if defined(MAT_MUL_NATIVE_T_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed - buffer only
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 > 0
 * @note Values > 8 for M0 and K0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_t_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
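    // LHS is transposed, so the y (M) coordinate advances along the innermost (x)
    // dimension of the stored tensor, hence the sizeof(DATA_TYPE) term below.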
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, K0, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, K0, at);
        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                at[j].s[i] = a[i].s[j];
            })
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, b, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, NT, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += K0 * rhs_stride_y;
    }

#if K % K0 != 0
    /* Leftover Loop */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, 1, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, 1, at);
        LOOP_UNROLLING(int, j, 0, 1, M0,
        {
            at[j].s[0] = a[0].s[j];
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, b, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, NT, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += 1 * rhs_stride_y;
    }
#endif // K % K0 != 0

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_T_NT)

#if defined(MAT_MUL_NATIVE_T_T)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed - buffer only
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_T)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_t_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
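    // Both operands are transposed: y (M) advances element-wise within a row of the
    // stored LHS, while x (N) advances row-wise (rhs_stride_y) in the stored RHS.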
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, K0, at);
        TILE(DATA_TYPE, K0, N0, bt);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                at[j].s[i] = a[i].s[j];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                bt[j].s[i] = b[i].s[j];
            })
        })

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if K % K0 != 0
    /* Leftover Loop */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, 1, at);
        TILE(DATA_TYPE, 1, N0, bt);

        LOOP_UNROLLING(int, j, 0, 1, M0,
        {
            at[j].s[0] = a[0].s[j];
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            bt[0].s[i] = b[i].s[0];
        })

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // K % K0 != 0

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_T_T)