/*
 * Copyright (c) 2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "helpers.h"
#include "tile_helpers.h"

#if defined(MAT_MUL_NATIVE_QUANTIZED_NT_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=uchar)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_QUANTIZED_NT_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 > 0
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0 are not expected to be efficient
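 * @note Example build options (illustrative values only; the quantization macros used by the kernel body, i.e. LHS_OFFSET, RHS_OFFSET, DST_OFFSET, DST_SHIFT and DST_MULTIPLIER, must also be defined):
 *       -DMAT_MUL_NATIVE_QUANTIZED_NT_NT -DDATA_TYPE=uchar -DM0=4 -DN0=4 -DK0=4 -DK=64 -DPARTIAL_STORE_M0=0 -DPARTIAL_STORE_N0=0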
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: QASYMM8_SIGNED/QASYMM8
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_quantized_nt_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
    lhs_offset_first_element_in_bytes += y * lhs_stride_y + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

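    // Offset-correction scheme used below: expanding the quantized product
    //   sum_k (lhs_k + LHS_OFFSET) * (rhs_k + RHS_OFFSET)
    //     =   sum_k lhs_k * rhs_k             (accumulated by T_MMUL in the loops)
    //       + RHS_OFFSET * sum_k lhs_k        (a_sum, applied after the loops)
    //       + LHS_OFFSET * sum_k rhs_k        (b_sum, applied after the loops)
    //       + K * LHS_OFFSET * RHS_OFFSET     (constant term used to seed acc)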
    // Initialize the accumulators
    TILE(int, M0, N0, acc);
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })

    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;

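    // Main loop: process K in blocks of K0. Each iteration loads an M0 x K0
    // tile of LHS and a K0 x N0 block of RHS (transposed on load into an
    // N0 x K0 tile) and accumulates their product into acc.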
    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tile from the lhs tensor
        T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);

        // Load tile from the rhs tensor in a transposed fashion
        // so that the T_MMUL_NT_T macro can be used, because only this macro
        // can utilize the dot-product instruction for Int8/UInt8 by
        // directly multiplying the rows of the Lhs and Rhs tensors.
        T_LOAD_TRANSPOSED(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, K0, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                a_sum[0].s[i] += (int)a[i].s[j];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, N0,
            {
                b_sum[0].s[j] += (int)b[j].s[i];
            })
        })

        lhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += K0 * rhs_stride_y;
    }

#if((K % K0) != 0)
    /* Leftover loop: process the remaining (K % K0) values of k one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tile from the lhs tensor
        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);

        // Load tile from the rhs tensor in a transposed fashion.
        // See the main loop for more explanation
        T_LOAD_TRANSPOSED(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, 1, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            LOOP_UNROLLING(int, j, 0, 1, 1,
            {
                a_sum[0].s[i] += (int)a[i].s[j];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            LOOP_UNROLLING(int, j, 0, 1, N0,
            {
                b_sum[0].s[j] += (int)b[j].s[i];
            })
        })

        lhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
        rhs_offset_first_element_in_bytes += 1 * rhs_stride_y;
    }
#endif // ((K % K0) != 0)

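    // Apply the row-sum/column-sum offset terms (see the expansion above)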
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })

    // With GET_SPATIAL_IDX, a possible partial block along each dimension is
    // handled by the first work-item of that dimension
    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Quantize the tile: requantize the int32 accumulators to the 8-bit
    // destination type using DST_MULTIPLIER, DST_SHIFT and DST_OFFSET
    TILE(DATA_TYPE, M0, N0, accq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, acc, accq);

    // Build the row-offset table for the store: when only PARTIAL_STORE_M0
    // rows are valid (y_cond), the remaining row indices are clamped so that
    // no out-of-bounds row is addressed
    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, accq, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_NT_NT)

#if defined(MAT_MUL_NATIVE_QUANTIZED_T_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=uchar)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_QUANTIZED_T_NT)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: QASYMM8/QASYMM8_SIGNED
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_quantized_t_nt(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
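    // LHS is transposed, so the y (M) coordinate moves along a row of the lhs
    // buffer (in elements), while the k loops below advance by lhs_stride_y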
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(int, M0, N0, acc);
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })

    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tiles from the lhs/rhs tensors in a transposed fashion
        // see mat_mul_native_quantized_nt_nt main loop for more explanation
        T_LOAD_TRANSPOSED(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD_TRANSPOSED(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, K0, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                a_sum[0].s[j] += (int)a[j].s[i];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, N0,
            {
                b_sum[0].s[j] += (int)b[j].s[i];
            })
        })

        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += K0 * rhs_stride_y;
    }

#if((K % K0) != 0)
    /* Leftover loop: process the remaining (K % K0) values of k one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tiles from the lhs/rhs tensors in a transposed fashion
        // see mat_mul_native_quantized_nt_nt main loop for more explanation
        T_LOAD_TRANSPOSED(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
        T_LOAD_TRANSPOSED(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, 1, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                a_sum[0].s[j] += (int)a[j].s[i];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            LOOP_UNROLLING(int, j, 0, 1, N0,
            {
                b_sum[0].s[j] += (int)b[j].s[i];
            })
        })

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += 1 * rhs_stride_y;
    }
#endif // ((K % K0) != 0)

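    // Apply the row-sum/column-sum offset terms
    // (see mat_mul_native_quantized_nt_nt for the derivation)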
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            acc[i].s[j] += ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Quantize the tile
    TILE(DATA_TYPE, M0, N0, accq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, acc, accq);

    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, accq, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_T_NT)

#if defined(MAT_MUL_NATIVE_QUANTIZED_T_T)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed
 *
 * @note The "batch" here expresses the number of matrix multiplications to run in parallel. However, it
 *       should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=uchar)
 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
 * @note The number of leftover output rows/columns must be passed using -DPARTIAL_STORE_N0 and -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_N0=2, -DPARTIAL_STORE_M0=3)
 * @note The dimension K must be passed at compile time using -DK (e.g. -DK=6)
 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_QUANTIZED_T_T)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 8, 16
 *  - N0 = 1, 2, 3, 4, 8, 16
 *  - K0 = 1, 2, 3, 4, 8, 16
 * @note Values > 8 for M0, N0 and K0 are not expected to be efficient
 *
 * @param[in]  lhs_ptr                           Pointer to the lhs matrix. Supported data types: QASYMM8/QASYMM8_SIGNED
 * @param[in]  lhs_stride_y                      Stride of the lhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  lhs_stride_z                      Stride of the lhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  lhs_w                             The width of the lhs tensor
 * @param[in]  lhs_h                             The height of the lhs tensor
 * @param[in]  lhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
 * @param[in]  rhs_ptr                           Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  rhs_stride_y                      Stride of the rhs matrix in Y (2nd) dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the rhs tensor in Z (3rd) dimension (in bytes)
 * @param[in]  rhs_w                             The width of the rhs tensor
 * @param[in]  rhs_h                             The height of the rhs tensor
 * @param[in]  rhs_n                             Number of matrices (buffers) in the batch
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
 * @param[out] dst_ptr                           Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
 * @param[in]  dst_stride_y                      Stride of the dst matrix in Y (2nd) dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the dst tensor in Z (3rd) dimension (in bytes)
 * @param[in]  dst_w                             The width of the dst tensor
 * @param[in]  dst_h                             The height of the dst tensor
 * @param[in]  dst_n                             Number of matrices (buffers) in the batch
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
 */
__kernel void mat_mul_native_quantized_t_t(
    TENSOR3D_T(lhs, BUFFER),
    TENSOR3D_T(rhs, BUFFER),
    TENSOR3D_T(dst, BUFFER))
{
    const uint x = GET_SPATIAL_IDX(0, N0, PARTIAL_STORE_N0);
    const uint y = GET_SPATIAL_IDX(1, M0, PARTIAL_STORE_M0);
    const uint z = GET_SPATIAL_IDX(2, 1, 0);

    // Compute LHS/RHS/DST matrix address
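    // Both LHS and RHS are transposed: the y (M) coordinate moves along a row
    // of the lhs buffer, and the x (N) coordinate selects a row of the rhs
    // buffer, hence x * rhs_stride_y below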
    lhs_offset_first_element_in_bytes += y * sizeof(DATA_TYPE) + z * lhs_stride_z;
    rhs_offset_first_element_in_bytes += x * rhs_stride_y + z * rhs_stride_z;
    dst_offset_first_element_in_bytes += x * sizeof(DATA_TYPE) + y * dst_stride_y + z * dst_stride_z;

    // Initialize the accumulators
    TILE(int, M0, N0, acc);
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        acc[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
    })

    TILE(int, 1, M0, a_sum);
    a_sum[0].v = 0;

    TILE(int, 1, N0, b_sum);
    b_sum[0].v = 0;

    int k;
    for(k = 0; k <= K - K0; k += K0)
    {
        TILE(DATA_TYPE, M0, K0, a);
        TILE(DATA_TYPE, N0, K0, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tile from the lhs tensor in a transposed fashion
        // see mat_mul_native_quantized_nt_nt main loop for more explanation
        T_LOAD_TRANSPOSED(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);

        // Load tile from the rhs tensor; as the rhs matrix is already transposed,
        // a plain row-wise load yields the N0 x K0 layout expected by T_MMUL
        T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, K0, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, K0,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                a_sum[0].s[j] += (int)a[j].s[i];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, K0,
            {
                b_sum[0].s[i] += (int)b[i].s[j];
            })
        })

        lhs_offset_first_element_in_bytes += K0 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += K0 * sizeof(DATA_TYPE);
    }

#if((K % K0) != 0)
    /* Leftover loop: process the remaining (K % K0) values of k one at a time */
    for(; k < K; ++k)
    {
        TILE(DATA_TYPE, M0, 1, a);
        TILE(DATA_TYPE, N0, 1, b);

        LOOP_UNROLLING(int, i, 0, 1, M0,
        {
            a[i].v = 0;
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            b[i].v = 0;
        })

        // Load tile from the lhs tensor in a transposed fashion
        // see mat_mul_native_quantized_nt_nt main loop for more explanation
        T_LOAD_TRANSPOSED(DATA_TYPE, 1, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);

        // Load tile from the rhs tensor
        T_LOAD(DATA_TYPE, N0, 1, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);

        T_MMUL(DATA_TYPE, DATA_TYPE, int, M0, N0, 1, NT, T, a, b, acc);

        LOOP_UNROLLING(int, i, 0, 1, 1,
        {
            LOOP_UNROLLING(int, j, 0, 1, M0,
            {
                a_sum[0].s[j] += (int)a[j].s[i];
            })
        })

        LOOP_UNROLLING(int, i, 0, 1, N0,
        {
            LOOP_UNROLLING(int, j, 0, 1, 1,
            {
                b_sum[0].s[i] += (int)b[i].s[j];
            })
        })

        lhs_offset_first_element_in_bytes += 1 * lhs_stride_y;
        rhs_offset_first_element_in_bytes += 1 * sizeof(DATA_TYPE);
    }
#endif // ((K % K0) != 0)

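    // Apply the row-sum/column-sum offset terms
    // (see mat_mul_native_quantized_nt_nt for the derivation)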
    LOOP_UNROLLING(int, i, 0, 1, M0,
    {
        LOOP_UNROLLING(int, j, 0, 1, N0,
        {
            acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
        })
    })

    const bool x_cond = PARTIAL_STORE_N0 != 0 && get_global_id(0) == 0;
    const bool y_cond = PARTIAL_STORE_M0 != 0 && get_global_id(1) == 0;

    // Quantize the tile
    TILE(DATA_TYPE, M0, N0, accq);
    T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, acc, accq);

    TILE(int, M0, 1, indirect_buffer);
    LOOP_UNROLLING(int, _i, 0, 1, M0,
    {
        indirect_buffer[_i].v = min(_i, select(M0 - 1, PARTIAL_STORE_M0 - 1, y_cond));
    });

    T_STORE_INDIRECT_WIDTH_SELECT(DATA_TYPE, M0, N0, PARTIAL_STORE_N0, BUFFER, dst, 0, dst_stride_y, x_cond, accq, indirect_buffer);
}
#endif // defined(MAT_MUL_NATIVE_QUANTIZED_T_T)