blob: 90d485e8159ba7ffe047200d3b1ec4b714d9fa08 [file] [log] [blame]
Ramy Elgammal2b6ebfe2023-03-09 21:15:37 +00001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25#include "tile_helpers.h"
26
#if defined(MAT_MUL_NATIVE_NT_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed
 *
 * @note The lhs and dst tensors are always accessed as buffers. The rhs tensor can be a buffer or an image (see RHS_TENSOR_TYPE)
 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 > 0
 *  - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0 are not expected to be efficient
 * @note When K is not a multiple of K0, the leftover K iterations always read the rhs tensor through rhs_ptr
 *       (BUFFER access), even when RHS_TENSOR_TYPE=IMAGE
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_img                           (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of the matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_nt_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/DST matrix address (RHS is addressed through tile coordinates in T_LOAD)
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    // RHS rows are addressed with a Y coordinate: matrix z of the batch starts at row z * rhs_h
    const int rhs_z = z * rhs_h;
    int       k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, K0, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, b, acc);

        // Advance along K in the (non-transposed) lhs rows
        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if K % K0 != 0
    /* Leftover Loop: processes the remaining K % K0 steps one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, 1, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors (rhs is always read through rhs_ptr here)
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, b, acc);

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // K % K0 != 0

    // Only the first work-item along each axis stores a partial block (when a partial size is configured)
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Clamp row indices so that rows beyond PARTIAL_STORE_M0 rewrite the last valid row instead of going out of bounds
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    })

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_NT_NT)
153
#if defined(MAT_MUL_NATIVE_NT_T)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS transposed
 *
 * @note The lhs and dst tensors are always accessed as buffers. The rhs tensor can be a buffer or an image (see RHS_TENSOR_TYPE)
 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_NT_T)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 > 0
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 * @note When K is not a multiple of K0, the leftover K iterations always read the rhs tensor through rhs_ptr
 *       (BUFFER access), even when RHS_TENSOR_TYPE=IMAGE
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_img                           (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of the matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_nt_t(TENSOR3D_T(lhs, BUFFER),
                                  TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
                                  TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/DST matrix address (RHS is addressed through tile coordinates in T_LOAD)
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    // RHS rows are addressed with a Y coordinate: matrix z of the batch starts at row z * rhs_h
    const int rhs_z = z * rhs_h;
    int       k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors; the transposed rhs is read as N0 rows of K0 elements
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, RHS_TENSOR_TYPE, rhs, k, x + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // This part is written to decrease the number of loop unrollings caused
        // by T_MMUL. The NT/NT version is partly vectorized and uses less number
        // of loop unrollings, and code behaves as expected. Although this is not
        // a performant solution for the specified architecture, it is necessary
        // to overcome some limitations.
        TILE(DATA_TYPE, K0, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                bt[j].s[i] = b[i].s[j];
            })
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, a, bt, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        // Advance along K in the (non-transposed) lhs rows
        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if K % K0 != 0
    /* Leftover Loop: processes the remaining K % K0 steps one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors. Image loads along K need K0 in {4, 8, 16} (see the notes above),
        // so the single-column rhs load always goes through rhs_ptr
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, k, x + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // See the main loop for the explanation of this part
        TILE(DATA_TYPE, 1, N0, bt);
        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            bt[0].s[i] = b[i].s[0];
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, a, bt, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // K % K0 != 0

    // Only the first work-item along each axis stores a partial block (when a partial size is configured)
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Clamp row indices so that rows beyond PARTIAL_STORE_M0 rewrite the last valid row instead of going out of bounds
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    })

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_NT_T)
307
#if defined(MAT_MUL_NATIVE_T_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed
 *
 * @note The lhs and dst tensors are always accessed as buffers. The rhs tensor can be a buffer or an image (see RHS_TENSOR_TYPE)
 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
 *  - K0 > 0
 * @note Values > 8 for M0 and K0 are not expected to be efficient
 * @note When K is not a multiple of K0, the leftover K iterations always read the rhs tensor through rhs_ptr
 *       (BUFFER access), even when RHS_TENSOR_TYPE=IMAGE
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_img                           (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of the matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_t_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/DST matrix address. The lhs is transposed (K rows of M elements), so y moves along its width
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    // RHS rows are addressed with a Y coordinate: matrix z of the batch starts at row z * rhs_h
    const int rhs_z = z * rhs_h;
    int       k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, K0, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, K0, N0, RHS_TENSOR_TYPE, rhs, x, k + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, K0, at);
        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                at[j].s[i] = a[i].s[j];
            })
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, b, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, NT, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        // Advance along K, which runs down the rows of the transposed lhs
        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
    }

#if K % K0 != 0
    /* Leftover Loop: processes the remaining K % K0 steps one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, 1, N0, b);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors (rhs is always read through rhs_ptr here)
        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, x, k + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, 1, at);
        LOOP_UNROLLING(int, j, 0, 1, M0,
        {
            at[j].s[0] = a[0].s[j];
        })
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, b, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, NT, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
    }
#endif // K % K0 != 0

    // Only the first work-item along each axis stores a partial block (when a partial size is configured)
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Clamp row indices so that rows beyond PARTIAL_STORE_M0 rewrite the last valid row instead of going out of bounds
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    })

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_T_NT)
457
#if defined(MAT_MUL_NATIVE_T_T)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed
 *
 * @note The lhs and dst tensors are always accessed as buffers. The rhs tensor can be a buffer or an image (see RHS_TENSOR_TYPE)
 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The tensor type ("BUFFER" or "IMAGE") of the rhs tensor must be passed at compile time using -DRHS_TENSOR_TYPE (e.g. -DRHS_TENSOR_TYPE=BUFFER)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_T_T)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16 (only 4, 8, 16 if RHS_TENSOR_TYPE=IMAGE)
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 * @note When K is not a multiple of K0, the leftover K iterations always read the rhs tensor through rhs_ptr
 *       (BUFFER access), even when RHS_TENSOR_TYPE=IMAGE
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: F32/F16
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_img                           (Optional) Read only cl_image object for the rhs tensor. Included when RHS_TENSOR_TYPE=IMAGE
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of the matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of the matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_t_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, RHS_TENSOR_TYPE),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/DST matrix address. The lhs is transposed (K rows of M elements), so y moves along its width
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(DATA_TYPE, M0, N0, acc);

    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = 0.f;
    })

    // RHS rows are addressed with a Y coordinate: matrix z of the batch starts at row z * rhs_h
    const int rhs_z = z * rhs_h;
    int       k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, K0, M0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors; the transposed rhs is read as N0 rows of K0 elements
        T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, K0, RHS_TENSOR_TYPE, rhs, k, x + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, K0, at);
        TILE(DATA_TYPE, K0, N0, bt);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                at[j].s[i] = a[i].s[j];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                bt[j].s[i] = b[i].s[j];
            })
        })

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, NT, NT, at, bt, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, K0, T, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        // Advance along K, which runs down the rows of the transposed lhs
        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
    }

#if K % K0 != 0
    /* Leftover Loop: processes the remaining K % K0 steps one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, 1, M0, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            a[i].v = 0.f;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0.f;
        })

        // Load tile from the lhs/rhs tensors. Image loads along K need K0 in {4, 8, 16} (see the notes above),
        // so the single-column rhs load always goes through rhs_ptr
        T_LOAD(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, k, x + rhs_z, 1, rhs_stride_y, b);

#if GPU_ARCH == GPU_ARCH_MIDGARD
        // For explanation, see mat_mul_native_nt_t
        TILE(DATA_TYPE, M0, 1, at);
        TILE(DATA_TYPE, 1, N0, bt);

        LOOP_UNROLLING(int, j, 0, 1, M0,
        {
            at[j].s[0] = a[0].s[j];
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            bt[0].s[i] = b[i].s[0];
        })

        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, NT, NT, at, bt, acc);
#else  // GPU_ARCH == GPU_ARCH_MIDGARD
        T_MMUL(DATA_TYPE, DATA_TYPE, DATA_TYPE, M0, N0, 1, T, T, a, b, acc);
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
    }
#endif // K % K0 != 0

    // Only the first work-item along each axis stores a partial block (when a partial size is configured)
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Clamp row indices so that rows beyond PARTIAL_STORE_M0 rewrite the last valid row instead of going out of bounds
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    })

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, acc, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_T_T)