blob: 90ebf80a6a92669911d6b417870d258306e1d26e [file] [log] [blame]
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25#include "tile_helpers.h"
26
27#if defined(MAT_MUL_NATIVE_NT_NT)
28/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
29 *
30 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
31 * should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
Gunes Bayir8918b232023-03-17 13:52:21 +000032 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000033 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
Gunes Bayir8918b232023-03-17 13:52:21 +000034 * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000035 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
Gunes Bayirbbeef722023-03-20 10:19:10 +000036 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
Gunes Bayir8918b232023-03-17 13:52:21 +000037 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_NT)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000038 * @note Only the following configurations of M0, N0 and K0 are currently supported:
39 * - M0 > 0
Gunes Bayirbbeef722023-03-20 10:19:10 +000040 * - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000041 * - K0 = 1, 2, 3, 4, 8, 16
42 * @note Values > 8 for M0 are not expected to be efficient
43 *
44 * @param[in] lhs_ptr Pointer to the lhs matrix. Supported data types: F32/F16
45 * @param[in] lhs_stride_y Stride of the lhs matrix in Y (2nd) dimension (in bytes)
46 * @param[in] lhs_stride_z Stride of the lhs tensor in Z (3rd) dimension (in bytes)
47 * @param[in] lhs_w The width of the lhs tensor
48 * @param[in] lhs_h The height of the lhs tensor
49 * @param[in] lhs_n Number of the matrices (buffers) in the batch
50 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
Gunes Bayirbbeef722023-03-20 10:19:10 +000051 * @param[in] rhs_img (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
Gunes Bayir8918b232023-03-17 13:52:21 +000052 * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000053 * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes)
54 * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes)
55 * @param[in] rhs_w The width of the rhs tensor
56 * @param[in] rhs_h The height of the rhs tensor
57 * @param[in] rhs_n Number of the matrices (buffers) in the batch
58 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
Gunes Bayir8918b232023-03-17 13:52:21 +000059 * @param[out] dst_ptr Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000060 * @param[in] dst_stride_y Stride of the dst matrix in Y (2nd) dimension (in bytes)
61 * @param[in] dst_stride_z Stride of the dst tensor in Z (3rd) dimension (in bytes)
62 * @param[in] dst_w The width of the dst tensor
63 * @param[in] dst_h The height of the dst tensor
64 * @param[in] dst_n Number of the matrices (buffers) in the batch
65 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
66 */
67__kernel void mat_mul_native_nt_nt(
68 TENSOR3D_T(lhs, BUFFER),
Gunes Bayirbbeef722023-03-20 10:19:10 +000069 TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000070 TENSOR3D_T(dst, BUFFER))
71{
72 const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
73 const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
74 const uint z = GET_SPATIAL_IDX(2, 1, 0);
75
76 // Compute LHS/RHS/DST matrix address
77 lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000078 dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;
79
80 // Initialize the accumulators
81 TILE(DATA_TYPE, M0, N0, acc);
82
83 LOOP_UNROLLING(int, i, 0, 1, M0,
84 {
85 acc[i].v = 0.f;
86 })
87
Gunes Bayirbbeef722023-03-20 10:19:10 +000088 const int rhs_z = z * rhs_h;
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +000089 int k;
90 for(k = 0; k <= K - K0; k += K0)
91 {
92 TILE(DATA_TYPE, M0, K0, a);
93 TILE(DATA_TYPE, K0, N0, b);
94
95 LOOP_UNROLLING(int, i, 0, 1, M0,
96 {
97 a[i].v = 0.f;
98 })
99
100 LOOP_UNROLLING(int, i, 0, 1, K0,
101 {
102 b[i].v = 0.f;
103 })
104
105 // Load tile from the lhs/rhs tensors
106 T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
Gunes Bayirbbeef722023-03-20 10:19:10 +0000107 T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b);
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000108
109 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, b, acc);
110
111 lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000112 }
113
114#ifdef K % K0 != 0
Gunes Bayir8918b232023-03-17 13:52:21 +0000115 /* Leftover Loop */
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000116 for(; k < K; ++k)
117 {
118 TILE(DATA_TYPE, M0, 1, a);
119 TILE(DATA_TYPE, 1, N0, b);
120
121 LOOP_UNROLLING(int, i, 0, 1, M0,
122 {
123 a[i].v = 0.f;
124 })
125
126 LOOP_UNROLLING(int, i, 0, 1, 1,
127 {
128 b[i].v = 0.f;
129 })
130
131 // Load tile from the lhs/rhs tensors
132 T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
Gunes Bayirbbeef722023-03-20 10:19:10 +0000133 T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b);
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000134
135 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, b, acc);
136
137 lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000138 }
139#endif // K % K0 != 0
140
141 const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
142 const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;
143
144 TILE(int, M0, 1, indirect_buffer);
145 LOOP_UNROLLING(int, _i, 0, 1, M0,
146 {
147 indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
148 });
149
150 T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
151}
152#endif // defined(MAT_MUL_NATIVE_NT_NT)
153
154#if defined(MAT_MUL_NATIVE_NT_T)
155/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS transposed - buffer only
156 *
157 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
158 * should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
Gunes Bayir8918b232023-03-17 13:52:21 +0000159 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000160 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
Gunes Bayir8918b232023-03-17 13:52:21 +0000161 * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000162 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
Gunes Bayir8918b232023-03-17 13:52:21 +0000163 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_T)
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000164 * @note Only the following configurations of M0, N0 and K0 are currently supported:
165 * - M0 > 0
166 * - N0 = 1, 2, 3, 4, 8, 16
167 * - K0 = 1, 2, 3, 4, 8, 16
168 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
169 *
170 * @param[in] lhs_ptr Pointer to the lhs matrix. Supported data types: F32/F16
171 * @param[in] lhs_stride_y Stride of the lhs matrix in Y (2nd) dimension (in bytes)
172 * @param[in] lhs_stride_z Stride of the lhs tensor in Z (3rd) dimension (in bytes)
173 * @param[in] lhs_w The width of the lhs tensor
174 * @param[in] lhs_h The height of the lhs tensor
175 * @param[in] lhs_n Number of the matrices (buffers) in the batch
176 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
Gunes Bayir8918b232023-03-17 13:52:21 +0000177 * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000178 * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes)
179 * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes)
180 * @param[in] rhs_w The width of the rhs tensor
181 * @param[in] rhs_h The height of the rhs tensor
182 * @param[in] rhs_n Number of the matrices (buffers) in the batch
183 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
Gunes Bayir8918b232023-03-17 13:52:21 +0000184 * @param[out] dst_ptr Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +0000185 * @param[in] dst_stride_y Stride of the dst matrix in Y (2nd) dimension (in bytes)
186 * @param[in] dst_stride_z Stride of the dst tensor in Z (3rd) dimension (in bytes)
187 * @param[in] dst_w The width of the dst tensor
188 * @param[in] dst_h The height of the dst tensor
189 * @param[in] dst_n Number of the matrices (buffers) in the batch
190 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
191 */
__kernel void mat_mul_native_nt_t(TENSOR3D_T(lhs, BUFFER),
                                  TENSOR3D_T(rhs, BUFFER),
                                  TENSOR3D_T(dst, BUFFER))

{
    // Coordinates of the output tile computed by this work-item.
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
    // RHS is transposed: the n coordinate (x) advances along rhs rows.
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    // Main loop: consumes K0 accumulations per iteration.
    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        // b is loaded in its transposed (N0 x K0) layout.
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // This part is written to decrease the number of loop unrollings caused
        // by T_MMUL. The NT/NT version is partly vectorized and uses less number
        // of loop unrollings, and code behaves as expected. Although this is not
        // a performant solution for the specified architecture, it is necessary
        // to overcome some limitations.
        // Transpose b in registers (bt = b^T) so the NT/NT T_MMUL variant can be used.
        TILE(DATA_TYPE, K0, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                bt[j].s[i] = b[i].s[j];
            })
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        // Both operands advance K0 contiguous elements along their K dimension.
        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if K % K0 != 0
    /* Leftover Loop */
    // Processes the remaining K % K0 accumulations one element at a time.
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // See the main loop for the explanation of this part
        TILE(DATA_TYPE, 1, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            bt[0].s[i] = b[i].s[0];
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, bt, acc);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // K % K0 != 0

    // Partial-store handling: only the work-item with global id 0 in a given
    // dimension owns the leftover rows/columns (when PARTIAL_STORE_* != 0).
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Clamp destination row indices so out-of-range rows are redirected to the
    // last valid row of the (possibly partial) block.
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
Gunes Bayir8918b232023-03-17 13:52:21 +0000306#endif // defined(MAT_MUL_NATIVE_NT_T)
307
308#if defined(MAT_MUL_NATIVE_T_NT)
309/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed - buffer only
310 *
311 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
312 * should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
313 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
314 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
315 * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
316 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
Gunes Bayirbbeef722023-03-20 10:19:10 +0000317 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
Gunes Bayir8918b232023-03-17 13:52:21 +0000318 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_NT)
319 * @note Only the following configurations of M0, N0 and K0 are currently supported:
320 * - M0 = 1, 2, 3, 4, 8, 16
Gunes Bayirbbeef722023-03-20 10:19:10 +0000321 * - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
Gunes Bayir8918b232023-03-17 13:52:21 +0000322 * - K0 > 0
 * @note Values > 8 for M0 and K0 are not expected to be efficient
324 *
325 * @param[in] lhs_ptr Pointer to the lhs matrix. Supported data types: F32/F16
326 * @param[in] lhs_stride_y Stride of the lhs matrix in Y (2nd) dimension (in bytes)
327 * @param[in] lhs_stride_z Stride of the lhs tensor in Z (3rd) dimension (in bytes)
328 * @param[in] lhs_w The width of the lhs tensor
329 * @param[in] lhs_h The height of the lhs tensor
330 * @param[in] lhs_n Number of the matrices (buffers) in the batch
331 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
Gunes Bayirbbeef722023-03-20 10:19:10 +0000332 * @param[in] rhs_img (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
Gunes Bayir8918b232023-03-17 13:52:21 +0000333 * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
334 * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes)
335 * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes)
336 * @param[in] rhs_w The width of the rhs tensor
337 * @param[in] rhs_h The height of the rhs tensor
338 * @param[in] rhs_n Number of the matrices (buffers) in the batch
339 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
340 * @param[out] dst_ptr Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
341 * @param[in] dst_stride_y Stride of the dst matrix in Y (2nd) dimension (in bytes)
342 * @param[in] dst_stride_z Stride of the dst tensor in Z (3rd) dimension (in bytes)
343 * @param[in] dst_w The width of the dst tensor
344 * @param[in] dst_h The height of the dst tensor
345 * @param[in] dst_n Number of the matrices (buffers) in the batch
346 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
347 */
348__kernel void mat_mul_native_t_nt(
349 TENSOR3D_T(lhs, BUFFER),
Gunes Bayirbbeef722023-03-20 10:19:10 +0000350 TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
Gunes Bayir8918b232023-03-17 13:52:21 +0000351 TENSOR3D_T(dst, BUFFER))
352{
353 const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
354 const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
355 const uint z = GET_SPATIAL_IDX(2, 1, 0);
356
357 // Compute LHS/RHS/DST matrix address
358 lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
Gunes Bayir8918b232023-03-17 13:52:21 +0000359 dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;
360
361 // Initialize the accumulators
362 TILE(DATA_TYPE, M0, N0, acc);
363
364 LOOP_UNROLLING(int, i, 0, 1, M0,
365 {
366 acc[i].v = 0.f;
367 })
368
Gunes Bayirbbeef722023-03-20 10:19:10 +0000369 const int rhs_z = z * rhs_h;
Gunes Bayir8918b232023-03-17 13:52:21 +0000370 int k;
371 for(k = 0; k <= K - K0; k += K0)
372 {
373 TILE(DATA_TYPE, K0, M0, a);
374 TILE(DATA_TYPE, K0, N0, b);
375
376 LOOP_UNROLLING(int, i, 0, 1, K0,
377 {
378 a[i].v = 0.f;
379 })
380
381 LOOP_UNROLLING(int, i, 0, 1, K0,
382 {
383 b[i].v = 0.f;
384 })
385
386 // Load tile from the lhs/rhs tensors
387 T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
Gunes Bayirbbeef722023-03-20 10:19:10 +0000388 T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b);
Gunes Bayir8918b232023-03-17 13:52:21 +0000389
390#if GPU_ARCH == GPU_ARCH_MIDGARD
391 // For explanation, see mat_mul_native_nt_t
392 TILE(DATA_TYPE, M0, K0, at);
393 LOOP_UNROLLING(int, i, 0, 1, K0,
394 {
395 LOOP_UNROLLING(int, j, 0, 1, M0,
396 {
397 at[j].s[i] = a[i].s[j];
398 })
399 })
400 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, b, acc);
401#else // GPU_ARCH == GPU_ARCH_MIDGARD
402 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, NT, a, b, acc);
403#endif // GPU_ARCH == GPU_ARCH_MIDGARD
404
405 lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
Gunes Bayir8918b232023-03-17 13:52:21 +0000406 }
407
408#ifdef K % K0 != 0
409 /* Leftover Loop */
410 for(; k < K; ++k)
411 {
412 TILE(DATA_TYPE, 1, M0, a);
413 TILE(DATA_TYPE, 1, N0, b);
414
415 LOOP_UNROLLING(int, i, 0, 1, 1,
416 {
417 a[i].v = 0.f;
418 })
419
420 LOOP_UNROLLING(int, i, 0, 1, 1,
421 {
422 b[i].v = 0.f;
423 })
424
425 // Load tile from the lhs/rhs tensors
426 T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
Gunes Bayirbbeef722023-03-20 10:19:10 +0000427 T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b);
Gunes Bayir8918b232023-03-17 13:52:21 +0000428
429#if GPU_ARCH == GPU_ARCH_MIDGARD
430 // For explanation, see mat_mul_native_nt_t
431 TILE(DATA_TYPE, M0, 1, at);
432 LOOP_UNROLLING(int, j, 0, 1, M0,
433 {
434 at[j].s[0] = a[0].s[j];
435 })
436 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, b, acc);
437#else // GPU_ARCH == GPU_ARCH_MIDGARD
438 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, NT, a, b, acc);
439#endif // GPU_ARCH == GPU_ARCH_MIDGARD
440
441 lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
Gunes Bayir8918b232023-03-17 13:52:21 +0000442 }
443#endif // K % K0 != 0
444
445 const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
446 const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;
447
448 TILE(int, M0, 1, indirect_buffer);
449 LOOP_UNROLLING(int, _i, 0, 1, M0,
450 {
451 indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
452 });
453
454 T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
455}
456#endif // defined(MAT_MUL_NATIVE_T_NT)
457
458#if defined(MAT_MUL_NATIVE_T_T)
459/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed - buffer only
460 *
461 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
462 * should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
463 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
464 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
465 * @note The number of leftover outputs rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
466 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_T)
468 * @note Only the following configurations of M0, N0 and K0 are currently supported:
469 * - M0 = 1, 2, 3, 4, 8, 16
470 * - N0 = 1, 2, 3, 4, 8, 16
471 * - K0 = 1, 2, 3, 4, 8, 16
472 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
473 *
474 * @param[in] lhs_ptr Pointer to the lhs matrix. Supported data types: F32/F16
475 * @param[in] lhs_stride_y Stride of the lhs matrix in Y (2nd) dimension (in bytes)
476 * @param[in] lhs_stride_z Stride of the lhs tensor in Z (3rd) dimension (in bytes)
477 * @param[in] lhs_w The width of the lhs tensor
478 * @param[in] lhs_h The height of the lhs tensor
479 * @param[in] lhs_n Number of the matrices (buffers) in the batch
480 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
481 * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
482 * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes)
483 * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes)
484 * @param[in] rhs_w The width of the rhs tensor
485 * @param[in] rhs_h The height of the rhs tensor
486 * @param[in] rhs_n Number of the matrices (buffers) in the batch
487 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
488 * @param[out] dst_ptr Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
489 * @param[in] dst_stride_y Stride of the dst matrix in Y (2nd) dimension (in bytes)
490 * @param[in] dst_stride_z Stride of the dst tensor in Z (3rd) dimension (in bytes)
491 * @param[in] dst_w The width of the dst tensor
492 * @param[in] dst_h The height of the dst tensor
493 * @param[in] dst_n Number of the matrices (buffers) in the batch
494 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
495 */
496__kernel void mat_mul_native_t_t(
497 TENSOR3D_T(lhs, BUFFER),
498 TENSOR3D_T(rhs, BUFFER),
499 TENSOR3D_T(dst, BUFFER))
500{
501 const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
502 const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
503 const uint z = GET_SPATIAL_IDX(2, 1, 0);
504
505 // Compute LHS/RHS/DST matrix address
506 lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
507 rhs_offset_first_element_in_bytes += x * rhs_stride_y + z * rhs_stride_z;
508 dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;
509
510 // Initialize the accumulators
511 TILE(DATA_TYPE, M0, N0, acc);
512
513 LOOP_UNROLLING(int, i, 0, 1, M0,
514 {
515 acc[i].v = 0.f;
516 })
517
518 int k;
519 for(k = 0; k <= K - K0; k += K0)
520 {
521 TILE(DATA_TYPE, K0, M0, a);
522 TILE(DATA_TYPE, N0, K0, b);
523
524 LOOP_UNROLLING(int, i, 0, 1, K0,
525 {
526 a[i].v = 0.f;
527 })
528
529 LOOP_UNROLLING(int, i, 0, 1, N0,
530 {
531 b[i].v = 0.f;
532 })
533
534 // Load tile from the lhs/rhs tensors
535 T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
536 T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
537
538#if GPU_ARCH == GPU_ARCH_MIDGARD
539 // For explanation, see mat_mul_native_nt_t
540 TILE(DATA_TYPE, M0, K0, at);
541 TILE(DATA_TYPE, K0, N0, bt);
542
543 LOOP_UNROLLING(int, i, 0, 1, K0,
544 {
545 LOOP_UNROLLING(int, j, 0, 1, M0,
546 {
547 at[j].s[i] = a[i].s[j];
548 })
549 })
550
551 LOOP_UNROLLING(int, i, 0, 1, N0,
552 {
553 LOOP_UNROLLING(int, j, 0, 1, K0,
554 {
555 bt[j].s[i] = b[i].s[j];
556 })
557 })
558
559 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, bt, acc);
560#else // GPU_ARCH == GPU_ARCH_MIDGARD
561 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, T, a, b, acc);
562#endif // GPU_ARCH == GPU_ARCH_MIDGARD
563
564 lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
565 rhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
566 }
567
568#ifdef K % K0 != 0
569 /* Leftover Loop */
570 for(; k < K; ++k)
571 {
572 TILE(DATA_TYPE, 1, M0, a);
573 TILE(DATA_TYPE, N0, 1, b);
574
575 LOOP_UNROLLING(int, i, 0, 1, 1,
576 {
577 a[i].v = 0.f;
578 })
579
580 LOOP_UNROLLING(int, i, 0, 1, N0,
581 {
582 b[i].v = 0.f;
583 })
584
585 // Load tile from the lhs/rhs tensors
586 T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
587 T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
588
589#if GPU_ARCH == GPU_ARCH_MIDGARD
590 // For explanation, see mat_mul_native_nt_t
591 TILE(DATA_TYPE, M0, 1, at);
592 TILE(DATA_TYPE, 1, N0, bt);
593
594 LOOP_UNROLLING(int, j, 0, 1, M0,
595 {
596 at[j].s[0] = a[0].s[j];
597 })
598
599 LOOP_UNROLLING(int, i, 0, 1, N0,
600 {
601 bt[0].s[i] = b[i].s[0];
602 })
603
604 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, bt, acc);
605#else // GPU_ARCH == GPU_ARCH_MIDGARD
606 T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, T, a, b, acc);
607#endif // GPU_ARCH == GPU_ARCH_MIDGARD
608
609 lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
610 rhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
611 }
612#endif // K % K0 != 0
613
614 const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
615 const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;
616
617 TILE(int, M0, 1, indirect_buffer);
618 LOOP_UNROLLING(int, _i, 0, 1, M0,
619 {
620 indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
621 });
622
623 T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
624}
625#endif // defined(MAT_MUL_NATIVE_T_T)