blob: fdfb75d39ccae744d41f2929be38f8af6a667094 [file] [log] [blame]
Gunes Bayire87fa662023-09-07 12:20:33 +01001/*
2 * Copyright (c) 2023 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "activation_float_helpers.h"
25#include "helpers.h"
26#include "tile_helpers.h"
27
28#ifdef BIAS
29// This function performs in-place bias addition for integer datatype when bias is enabled.
30// Note: the tile dimensions (M0, N0) must be passed at compile time using -DM0 and -DN0 (e.g. -DN0=8, -DM0=4).
31inline void perform_bias_addition(uchar *bias_ptr, uint bias_offset_first_element_in_bytes, TILE(int, M0, N0, acc), uint x)
32{
 // Purpose: add one 1 x N0 row of int bias values to the M0 x N0 int tile acc, storing the result back in acc.
 // bias_ptr and bias_offset_first_element_in_bytes are never named directly in this body; they are only
 // reachable through the T_LOAD expansion (see the note on the load below).
 // x is forwarded unchanged as the X argument of T_LOAD; the kernels in this file pass dst_x for it.
33 TILE(int, 1, N0, bias_tile);
34
35 // below expands to use bias_ptr and bias_offset_first_element_in_bytes
 // NOTE(review): presumably the BUFFER variant of T_LOAD builds those identifiers from the tensor name bias
 // by token pasting - confirm against tile_helpers.h.
36 T_LOAD(int, 1, N0, BUFFER, bias, x, 0, 1, 0, bias_tile);
37
38 // acc = acc + bias[broadcasted]: acc is passed as both the first operand and the destination
39 T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, acc, bias_tile, acc);
40}
41#endif // defined(BIAS)
42
Gunes Bayira116cd32023-09-13 11:59:34 +010043#define MMUL_BLOCK_SIZE (MMUL_M0 * MMUL_N0) // MMUL block size for the output matrix
44
Gunes Bayire87fa662023-09-07 12:20:33 +010045#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
46/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
47 *
Gunes Bayira116cd32023-09-13 11:59:34 +010048 * @note the "batch" here expresses the number of matrix multiplications to run in parallel. However, it
49 * should NOT be confused with the batch size of the model. For NHWC the "batch" is the "H" dimension
50 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=uchar)
51 * @note The block's dimensions used for the LHS and RHS matrices (M0, N0 and K0) must be passed at
52 * compile time using -DN0, -DM0 and -DK0 (e.g. -DN0=8, -DM0=4, -DK0=4).
53 * @note The number of leftover outputs rows/columns must be passed using -DN0_LEFTOVER and -DM0_LEFTOVER
54 * (e.g. -DN0_LEFTOVER=2, -DM0_LEFTOVER=3)
55 * @note The dimensions M, N, K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=5, -DN=8, -DK=32).
56 * K must be a multiple of 16.
57 * @note MMUL block sizes must be passed at compile time using -DMMUL_K0, -DMMUL_M0, -DMMUL_N0
58 * (e.g. -DMMUL_K0=16, -DMMUL_M0=4, -DMMUL_N0=4)
59 * @note If there is bias -DBIAS option must be passed at compile time
60 * @note Quantization offsets of lhs, rhs and dst tensors must be passed at compile time using -DLHS_OFFSET,
61 * -DRHS_OFFSET, -DDST_OFFSET (e.g. -DLHS_OFFSET=10, -DRHS_OFFSET=0, -DDST_OFFSET=-6)
62 * @note Effective quantization multiplier and shift for the destination tensor must be passed at compile time using
63 * -DDST_MULTIPLIER and -DDST_SHIFT (e.g. -DDST_MULTIPLIER=2091, -DDST_SHIFT=8)
64 * @note The kernel name in uppercase must be passed at compile time (e.g. -DMAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
65 * @note Only the following configurations of M0, N0 and K0 are currently supported:
66 * - M0 > 0
67 * - N0 = 1, 2, 3, 4, 8, 16
68 * - K0 = 4
69 * @note For a generic view on how the MMUL works, see mat_mul_mmul.cl
Gunes Bayire87fa662023-09-07 12:20:33 +010070 *
71 * @param[in] lhs_ptr Pointer to the lhs matrix. Supported data types: QASYMM8_SIGNED/QASYMM8
72 * @param[in] lhs_stride_y Stride of the lhs matrix in Y (2nd) dimension (in bytes)
73 * @param[in] lhs_stride_z Stride of the lhs tensor in Z (3rd) dimension (in bytes)
74 * @param[in] lhs_w The width of the lhs tensor
75 * @param[in] lhs_h The height of the lhs tensor
76 * @param[in] lhs_n Number of the matrices (buffers) in the batch
77 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the lhs matrix
78 * @param[in] rhs_ptr Pointer to the rhs matrix. Supported data types: same as @p lhs_ptr
79 * @param[in] rhs_stride_y Stride of the rhs matrix in Y (2nd) dimension (in bytes)
80 * @param[in] rhs_stride_z Stride of the rhs tensor in Z (3rd) dimension (in bytes)
81 * @param[in] rhs_w The width of the rhs tensor
82 * @param[in] rhs_h The height of the rhs tensor
83 * @param[in] rhs_n Number of the matrices (buffers) in the batch
84 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the rhs matrix
Gunes Bayira116cd32023-09-13 11:59:34 +010085 * @param[in] bias_ptr (Optional) Pointer to the bias tensor. Supported data type: S32
Gunes Bayire87fa662023-09-07 12:20:33 +010086 * @param[in] bias_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes)
87 * @param[in] bias_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes)
88 * @param[in] bias_w (Optional) The size of the width dimension of the bias tensor
89 * @param[in] bias_h (Optional) The size of the height dimension of the bias tensor
90 * @param[in] bias_n (Optional) The size of the depth dimension of the bias tensor
91 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor
92 * @param[out] dst_ptr Pointer to the dst matrix. Supported data types: same as @p lhs_ptr
93 * @param[in] dst_stride_y Stride of the dst matrix in Y (2nd) dimension (in bytes)
94 * @param[in] dst_stride_z Stride of the dst tensor in Z (3rd) dimension (in bytes)
95 * @param[in] dst_w The width of the dst tensor
96 * @param[in] dst_h The height of the dst tensor
97 * @param[in] dst_n Number of the matrices (buffers) in the batch
98 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the dst matrix
99 */
100__kernel void mat_mul_native_quantized_mmul_nt_nt(
101 TENSOR3D_T(lhs, BUFFER),
102 TENSOR3D_T(rhs, BUFFER),
103#ifdef BIAS
104 TENSOR3D_T(bias, BUFFER),
105#endif // defined(BIAS)
106 TENSOR3D_T(dst, BUFFER))
107{
Gunes Bayira116cd32023-09-13 11:59:34 +0100108 // The explanation of how this kernel works is very similar to the explanation given in
109 // mat_mul_mmul.cl. The MMUL logic and terminology are the same. The only difference is that,
110 // in quantized multiplication, the MMUL block sizes are (4 x 16) for Lhs matrix and
111 // (16 x 4) for Rhs matrix, resulting in (4 x 4) MMUL block size for the destination.
112 //
113 // Figures 1, 2 and 3 in the previous explanation work the same. Since the Lhs and Rhs
114 // MMUL block sizes are different in quantized extension, the thread access pattern is
115 // slightly different. We can redraw Figure 4 (Thread access pattern) as follows:
116 //
117 // (Modified Figure 4 from mat_mul_mmul.cl)
118 // Thread Access Layouts in LHS & RHS matrices
119 //
120 // LHS matrix
121 // 4 times 4 times 4 times 4 times
122 // _______________________________________________________________
123 // |T0_|T0_|T0_|T0_|T1_|T1_|T1_|T1_|T2_|T2_|T2_|T2_|T3_|T3_|T3_|T3_|
124 // |T0_| ... |
125 // M0 | . . |
126 // Times | . . |
127 // | . . |
128 // |T0_|T0_|T0_|T0_|T1_|T1_|T1_|T1_|T2_|T2_|T2_|T2_|T3_|T3_|T3_|T3_|
129 // |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
130 // |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
131 // M0 | . . |
132 // Times | . . |
133 // | . . |
134 // |T4_|T4_|T4_|T4_|T5_|T5_|T5_|T5_|T6_|T6_|T6_|T6_|T7_|T7_|T7_|T7_|
135 // |T8_|T8_|T8_|T8_|T9_|T9_|T9_|T9_|T10|T10|T10|T10|T11|T11|T11|T11|
136 // M0 | . |
137 // Times | . |
138 // | . |
139 // |T8_|T8_|T8_|T8_|T9_|T9_|T9_|T9_|T10|T10|T10|T10|T11|T11|T11|T11|
140 // M0 | . |
141 // Times | . |
142 // | . |
143 // |T12|T12|T12|T12|T13|T13|T13|T13|T14|T14|T14|T14|T15|T15|T15|T15|
144 //
145 //
146 // RHS Matrix
147 //
148 // __________N0 times______N0 times____________________N0 times_______
149 // |__T0__| ... |__T0__|__T1__| ... |__T1__| ... |__T3__| ... |__T3__|
150 // 4 times |__T0__| ... |__T0__|__T1__| ... |__T1__| ... |__T3__| ... |__T3__|
151 // |__T0__| ... |__T0__|__T1__| ... |__T1__| ... |__T3__| ... |__T3__|
152 // |__T0__| ... |__T0__|__T1__| ... |__T1__| ... |__T3__| ... |__T3__|
153 // |__T4__| ... |__T4__|__T5__| ... |__T5__| ... |__T7__| ... |__T7__|
154 // 4 times |__T4__| ... |__T4__|__T5__| ... |__T5__| ... |__T7__| ... |__T7__|
155 // |__T4__| ... |__T4__|__T5__| ... |__T5__| ... |__T7__| ... |__T7__|
156 // X |__T4__| ... |__T4__|__T5__| ... |__T5__| ... |__T7__| ... |__T7__|
157 // |__T8__| ... |__T8__|__T9__| ... |__T9__| ... |__T11_| ... |__T11_|
158 // |__T8__| ... |__T8__|__T9__| ... |__T9__| ... |__T11_| ... |__T11_|
159 // 4 times |__T8__| ... |__T8__|__T9__| ... |__T9__| ... |__T11_| ... |__T11_|
160 // |__T8__| ... |__T8__|__T9__| ... |__T9__| ... |__T11_| ... |__T11_|
161 // |__T12_| ... |__T12_|__T13_| ... |__T13_| ... |__T15_| ... |__T15_|
162 // 4 times |__T12_| ... |__T12_|__T13_| ... |__T13_| ... |__T15_| ... |__T15_|
163 // |__T12_| ... |__T12_|__T13_| ... |__T13_| ... |__T15_| ... |__T15_|
164 // |__T12_|_____|__T12_|__T13_|______|__T13_|_____|__T15_|_____|__T15_|
165 //
166 //
167 // The logic behind this thread access pattern is already described in the explanation
168 // in mat_mul_mmul.cl. The only change is that thread accesses are extended to 4 elements
169 // from 1, in rightward direction in Lhs, and in downward direction in Rhs, because they
170 // are now operating on 4 char/uchar's (again 32-bit data), instead of one 32-bit floating point.
171 //
172 // The mathematical view of the matrix multiplication explained in Figure 5 also holds for this,
173 // except the dimension 4 is 16 instead, but the vector notations do not change, i.e. it's as follows:
174 //
175 // Settings:
176 // - a 8 x 16 LHS section
177 // - 16 x 8 RHS section
178 // - Each vector variable ai, bj represent a 16x1 vector
179 // - ^T (superscript T) denotes transpose
180 // - M0 = N0 = 2
181 // - MMUL_N0 = MMUL_M0 = 4, MMUL_K0 = 16
182 //
183 //
184 // (Modified Figure 5)
185 // Mathematical view of the Matrix Multiplication
186 //
187 // LHS RHS DST
188 // [ a1^T ] [ b1 b2 b3 b4 b5 b6 b7 b8 ] [ a1^Tb1 a1^Tb2 a1^Tb3 ... a1^Tb8 ]
189 // [ a2^T ] 16 x 8 [ a2^Tb1 a2^Tb2 a2^Tb3 ... a2^Tb8 ]
190 // [ a3^T ] [ ]
191 // [ a4^T ] = [ . . ]
192 // [ a5^T ] X [ . . ]
193 // [ a6^T ] [ . . ]
194 // [ a7^T ] [ ]
195 // [ a8^T ] [ a8^Tb1 a8^Tb2 a8^Tb3 ... a8^Tb8 ]
196 // 8 x 16 8 x 8
197 //
198 //
199 // For the first iteration, i.e. (m0, n0) = (0, 0), the arm_matrix_multiply would multiply the following matrices:
200 //
201 // [ a1^T ] [ b1 b3 b5 b7 ] [ a1^Tb1 a1^Tb3 a1^Tb5 a1^Tb7 ]
202 // [ a3^T ] x 4 x 4 = [ a3^Tb1 a3^Tb3 a3^Tb5 a3^Tb7 ]
203 // [ a5^T ] [ a5^Tb1 a5^Tb3 a5^Tb5 a5^Tb7 ]
204 // [ a7^T ] [ a7^Tb1 a7^Tb3 a7^Tb5 a7^Tb7 ]
205 // 4 x 4 4 x 4
206 // The elements calculated in the 4x4 output block are the "interleaved" elements in the DST above.
207 // When we follow for each combination of (m0, n0), every element of the DST matrix "section" is filled.
208 //
209 // Please refer to mat_mul_mmul.cl for more details.
210
211 const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
212 // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
213 const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
214 const uint z = get_global_id(2); // Batch
215
216 // Get section coordinates
 // One section spans MMUL_BLOCK_SIZE consecutive work-items along dimension 0
217 const uint section_x = (x0 / MMUL_BLOCK_SIZE);
218 const uint section_y = y0;
219
220 // Get thread coordinates within an mmul block
221 const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
222 const uint thread_x = thread_id % MMUL_N0;
223 const uint thread_y = (thread_id / MMUL_N0);
224
225 // Calculate dst coordinates
 // The clamped values never exceed (N - N0, M - M0) and are the ones used for addressing.
 // The unclamped values are kept only for the out-of-bound test after the K loop.
226 const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
227 const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
228 const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
229 const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
230
231 // Starting LHS coordinates
 // LHS is not transposed: the K0-wide slice is selected along x by thread_x
232 const uint lhs_x = K0 * thread_x;
233 const uint lhs_y = dst_y;
234
235 // Starting RHS coordinates
 // RHS is not transposed: the K0-deep slice is selected along y by thread_y
236 const uint rhs_x = dst_x;
237 const uint rhs_y = K0 * thread_y;
238
239 // Compute LHS/RHS/DST matrix address
240 lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
241 rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
242 dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
243
244 // Initialize the accumulators
 // Each accumulator is seeded with the constant term K * LHS_OFFSET * RHS_OFFSET.
 // The terms that depend on the row/column sums are subtracted after the K loop.
245 TILE(int, M0, N0, c);
246 LOOP_UNROLLING(int, i, 0, 1, M0,
247 {
248 c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
249 })
250
251 // Calculate row and column sums
 // b_sum is only accumulated when LHS_OFFSET != 0 and a_sum only when RHS_OFFSET != 0; otherwise they stay 0
252 TILE(int, 1, N0, b_sum);
253 b_sum[0].v = 0;
254
255 TILE(int, 1, M0, a_sum);
256 a_sum[0].v = 0;
257
 // All-ones vector used as the second operand when accumulating the sums.
 // The four-element initializer matches the documented K0 = 4.
258 VEC_DATA_TYPE(DATA_TYPE, K0)
259 vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
260
 // Walk the K dimension in steps of MMUL_K0; lhs_w is the loop bound
261 for(int k = 0; k < lhs_w; k += MMUL_K0)
262 {
263 // LHS tile: M0 rows, each row a K0-wide vector
264 TILE(DATA_TYPE, M0, K0, a);
265 // RHS tile: K0 rows, each row an N0-wide vector
266 TILE(DATA_TYPE, K0, N0, b);
267
268 // Load tile from the lhs/rhs tensors
269 T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
270 T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
271
Gunes Bayir1f841a52023-09-19 17:57:29 +0100272 LOOP_UNROLLING(int, n0, 0, 1, N0,
Gunes Bayira116cd32023-09-13 11:59:34 +0100273 {
 // Gather column n0 of the RHS tile into one K0-wide vector
Gunes Bayir1f841a52023-09-19 17:57:29 +0100274 VEC_DATA_TYPE(DATA_TYPE, K0)
275 vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
276
277 LOOP_UNROLLING(int, m0, 0, 1, M0,
Gunes Bayira116cd32023-09-13 11:59:34 +0100278 {
Gunes Bayira116cd32023-09-13 11:59:34 +0100279 c[m0].s[n0] = arm_matrix_multiply(a[m0].v, vec_b, c[m0].s[n0]);
280 })
Gunes Bayir1f841a52023-09-19 17:57:29 +0100281
282#if LHS_OFFSET != 0
283 // Column Sum of B: Calculate the sum of columns by multiplying B
284 // with a matrix of 1's from Left
285 b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
286#endif // LHS_OFFSET != 0
Gunes Bayira116cd32023-09-13 11:59:34 +0100287 })
288
289#if RHS_OFFSET != 0
290 // Row Sum of A: Calculate the sum of rows by multiplying A with
291 // a matrix of 1's from Right
292 LOOP_UNROLLING(int, m0, 0, 1, M0,
293 {
294 a_sum[0].s[m0] = arm_matrix_multiply(a[m0].v, vec_1, a_sum[0].s[m0]);
295 })
296#endif // RHS_OFFSET != 0
297
 // Advance along K: LHS moves MMUL_K0 elements along x, RHS moves MMUL_K0 rows along y
Gunes Bayira116cd32023-09-13 11:59:34 +0100298 lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
299 rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
300 }
301
302 // Do not write if the coordinates are out of bound
303 // But, read has to happen as arm_matrix_multiply() expects certain number of calls
304 if(dst_x_unclamped >= N || dst_y_unclamped >= M)
305 {
306 return;
307 }
308
309#if RHS_OFFSET != 0 || LHS_OFFSET != 0
 // Offset correction: subtract RHS_OFFSET * a_sum[i] + LHS_OFFSET * b_sum[j] from c[i][j]
310 LOOP_UNROLLING(int, i, 0, 1, M0,
311 {
312 const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
313 LOOP_UNROLLING(int, j, 0, 1, N0,
314 {
315 c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
316 })
317 })
318#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
319
320#ifdef BIAS
 // The bias row is selected with the clamped dst_x
321 perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
322#endif // defined(BIAS)
323
324 // Quantize the tile
325 TILE(DATA_TYPE, M0, N0, cq);
326 T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
327
 // Full-width row stores when the tile fits within N or there is no N0 leftover.
 // In both branches a row is skipped when dst_y + m0 >= M and M0_LEFTOVER != 0.
328 if(dst_x + N0 <= N || N0_LEFTOVER == 0)
329 {
330 LOOP_UNROLLING(int, m0, 0, 1, M0,
331 {
332 if(dst_y + m0 < M || M0_LEFTOVER == 0)
333 {
334 VSTORE(N0)
335 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
336 }
337 })
338 }
339 else
340 {
 // Partial row stores limited to N0_LEFTOVER elements
341 LOOP_UNROLLING(int, m0, 0, 1, M0,
342 {
343 if(dst_y + m0 < M || M0_LEFTOVER == 0)
344 {
345 VSTORE_PARTIAL(N0, N0_LEFTOVER)
346 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
347 }
348 })
349 }
Gunes Bayire87fa662023-09-07 12:20:33 +0100350}
351#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_NT)
352
353#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_T)
354/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS transposed - buffer only
355 *
356 * Supported block configurations:
Gunes Bayir2ad0a6b2023-09-19 15:37:38 +0100357 * - M0 > 0
358 * - N0 = 1, 2, 3, 4, 8, 16
359 * - K0 = 4
Gunes Bayire87fa662023-09-07 12:20:33 +0100360 *
361 * Similar to mat_mul_native_quantized_mmul_nt_nt()
362 */
363__kernel void mat_mul_native_quantized_mmul_nt_t(
364 TENSOR3D_T(lhs, BUFFER),
365 TENSOR3D_T(rhs, BUFFER),
366#ifdef BIAS
367 TENSOR3D_T(bias, BUFFER),
368#endif // defined(BIAS)
369 TENSOR3D_T(dst, BUFFER))
370{
Gunes Bayir2ad0a6b2023-09-19 15:37:38 +0100371 const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
372 // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
373 const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
374 const uint z = get_global_id(2); // Batch
375
376 // Get section coordinates
 // One section spans MMUL_BLOCK_SIZE consecutive work-items along dimension 0
377 const uint section_x = (x0 / MMUL_BLOCK_SIZE);
378 const uint section_y = y0;
379
380 // Get thread coordinates within an mmul block
381 const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
382 const uint thread_x = thread_id % MMUL_N0;
383 const uint thread_y = (thread_id / MMUL_N0);
384
385 // Calculate dst coordinates
 // The clamped values never exceed (N - N0, M - M0) and are the ones used for addressing.
 // The unclamped values are kept only for the out-of-bound test after the K loop.
386 const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
387 const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
388 const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
389 const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
390
391 // Starting LHS coordinates
 // LHS is not transposed: the K0-wide slice is selected along x by thread_x
392 const uint lhs_x = K0 * thread_x;
393 const uint lhs_y = dst_y;
394
395 // Starting RHS coordinates
 // RHS is transposed: the K0-wide slice is selected along x by thread_y and the output column maps to y
396 const uint rhs_x = K0 * thread_y;
397 const uint rhs_y = dst_x;
398
399 // Compute LHS/RHS/DST matrix address
400 lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
401 rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
402 dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
403
404 // Initialize the accumulators
 // Each accumulator is seeded with the constant term K * LHS_OFFSET * RHS_OFFSET.
 // The terms that depend on the row/column sums are subtracted after the K loop.
405 TILE(int, M0, N0, c);
406 LOOP_UNROLLING(int, i, 0, 1, M0,
407 {
408 c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
409 })
410
411 // Calculate row and column sums
 // b_sum is only accumulated when LHS_OFFSET != 0 and a_sum only when RHS_OFFSET != 0; otherwise they stay 0
412 TILE(int, 1, N0, b_sum);
413 b_sum[0].v = 0;
414
415 TILE(int, 1, M0, a_sum);
416 a_sum[0].v = 0;
417
 // All-ones vector used as the second operand when accumulating the sums.
 // The four-element initializer matches the documented K0 = 4.
418 VEC_DATA_TYPE(DATA_TYPE, K0)
419 vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
420
 // Walk the K dimension in steps of MMUL_K0; lhs_w is the loop bound
421 for(int k = 0; k < lhs_w; k += MMUL_K0)
422 {
423 // LHS tile: M0 rows, each row a K0-wide vector
424 TILE(DATA_TYPE, M0, K0, a);
425 // RHS tile (transposed layout): N0 rows, each row a K0-wide vector
426 TILE(DATA_TYPE, N0, K0, b);
427
428 // Load tile from the lhs/rhs tensors
429 T_LOAD(DATA_TYPE, M0, K0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
430 T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
431
 // Rows of both tiles are already K0-wide vectors, so they are passed to arm_matrix_multiply directly
432 LOOP_UNROLLING(int, m0, 0, 1, M0,
433 {
434 LOOP_UNROLLING(int, n0, 0, 1, N0,
435 {
436 c[m0].s[n0] = arm_matrix_multiply(a[m0].v, b[n0].v, c[m0].s[n0]);
437 })
438 })
439
440#if RHS_OFFSET != 0
441 // Row Sum of A: Calculate the sum of rows by multiplying A with
442 // a matrix of 1's from Right
443 LOOP_UNROLLING(int, m0, 0, 1, M0,
444 {
445 a_sum[0].s[m0] = arm_matrix_multiply(a[m0].v, vec_1, a_sum[0].s[m0]);
446 })
447#endif // RHS_OFFSET != 0
448
449#if LHS_OFFSET != 0
450 // Column Sum of B: Calculate the sum of columns by multiplying B
451 // with a matrix of 1's from Left
452 LOOP_UNROLLING(int, n0, 0, 1, N0,
453 {
454 b_sum[0].s[n0] = arm_matrix_multiply(vec_1, b[n0].v, b_sum[0].s[n0]);
455 })
456#endif // LHS_OFFSET != 0
457
 // Advance along K: both operands move MMUL_K0 elements along x
458 lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
459 rhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
460 }
461
462 // Do not write if the coordinates are out of bound
463 // But, read has to happen as arm_matrix_multiply() expects certain number of calls
464 if(dst_x_unclamped >= N || dst_y_unclamped >= M)
465 {
466 return;
467 }
468
469#if RHS_OFFSET != 0 || LHS_OFFSET != 0
 // Offset correction: subtract RHS_OFFSET * a_sum[i] + LHS_OFFSET * b_sum[j] from c[i][j]
470 LOOP_UNROLLING(int, i, 0, 1, M0,
471 {
472 const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
473 LOOP_UNROLLING(int, j, 0, 1, N0,
474 {
475 c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
476 })
477 })
478#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
479
480#ifdef BIAS
 // The bias row is selected with the clamped dst_x
481 perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
482#endif // defined(BIAS)
483
484 // Quantize the tile
485 TILE(DATA_TYPE, M0, N0, cq);
486 T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
487
 // Full-width row stores when the tile fits within N or there is no N0 leftover.
 // In both branches a row is skipped when dst_y + m0 >= M and M0_LEFTOVER != 0.
488 if(dst_x + N0 <= N || N0_LEFTOVER == 0)
489 {
490 LOOP_UNROLLING(int, m0, 0, 1, M0,
491 {
492 if(dst_y + m0 < M || M0_LEFTOVER == 0)
493 {
494 VSTORE(N0)
495 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
496 }
497 })
498 }
499 else
500 {
 // Partial row stores limited to N0_LEFTOVER elements
501 LOOP_UNROLLING(int, m0, 0, 1, M0,
502 {
503 if(dst_y + m0 < M || M0_LEFTOVER == 0)
504 {
505 VSTORE_PARTIAL(N0, N0_LEFTOVER)
506 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
507 }
508 })
509 }
Gunes Bayire87fa662023-09-07 12:20:33 +0100510}
511#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_NT_T)
512
513#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_NT)
514/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS non-transposed
515 *
516 * Supported block configurations:
Gunes Bayira396da12023-09-20 10:09:43 +0100517 * - M0 = 1, 2, 3, 4, 8, 16
518 * - N0 = 1, 2, 3, 4, 8, 16
519 * - K0 = 4
Gunes Bayire87fa662023-09-07 12:20:33 +0100520 *
521 * Similar to mat_mul_native_quantized_mmul_nt_nt()
522 */
523__kernel void mat_mul_native_quantized_mmul_t_nt(
524 TENSOR3D_T(lhs, BUFFER),
525 TENSOR3D_T(rhs, BUFFER),
526#ifdef BIAS
527 TENSOR3D_T(bias, BUFFER),
528#endif // defined(BIAS)
529 TENSOR3D_T(dst, BUFFER))
530{
Gunes Bayira396da12023-09-20 10:09:43 +0100531 const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
532 // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
533 const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
534 const uint z = get_global_id(2); // Batch
535
536 // Get section coordinates
 // One section spans MMUL_BLOCK_SIZE consecutive work-items along dimension 0
537 const uint section_x = (x0 / MMUL_BLOCK_SIZE);
538 const uint section_y = y0;
539
540 // Get thread coordinates within an mmul block
541 const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
542 const uint thread_x = thread_id % MMUL_N0;
543 const uint thread_y = (thread_id / MMUL_N0);
544
545 // Calculate dst coordinates
 // The clamped values never exceed (N - N0, M - M0) and are the ones used for addressing.
 // The unclamped values are kept only for the out-of-bound test after the K loop.
546 const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
547 const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
548 const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
549 const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
550
551 // Starting LHS coordinates
 // LHS is transposed: the output row maps to x and the K0-deep slice is selected along y by thread_x
552 const uint lhs_x = dst_y;
553 const uint lhs_y = K0 * thread_x;
554
555 // Starting RHS coordinates
 // RHS is not transposed: the K0-deep slice is selected along y by thread_y
556 const uint rhs_x = dst_x;
557 const uint rhs_y = K0 * thread_y;
558
559 // Compute LHS/RHS/DST matrix address
560 lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
561 rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
562 dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
563
564 // Initialize the accumulators
 // Each accumulator is seeded with the constant term K * LHS_OFFSET * RHS_OFFSET.
 // The terms that depend on the row/column sums are subtracted after the K loop.
565 TILE(int, M0, N0, c);
566 LOOP_UNROLLING(int, i, 0, 1, M0,
567 {
568 c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
569 })
570
571 // Calculate row and column sums
 // b_sum is only accumulated when LHS_OFFSET != 0 and a_sum only when RHS_OFFSET != 0; otherwise they stay 0
572 TILE(int, 1, N0, b_sum);
573 b_sum[0].v = 0;
574
575 TILE(int, 1, M0, a_sum);
576 a_sum[0].v = 0;
577
 // All-ones vector used as the second operand when accumulating the sums.
 // The four-element initializer matches the documented K0 = 4.
578 VEC_DATA_TYPE(DATA_TYPE, K0)
579 vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
580
 // Walk the K dimension in steps of MMUL_K0; with a transposed LHS the loop bound is lhs_h
581 for(int k = 0; k < lhs_h; k += MMUL_K0)
582 {
 // LHS tile (transposed layout): K0 rows, each row an M0-wide vector
583 TILE(DATA_TYPE, K0, M0, a);
 // RHS tile: K0 rows, each row an N0-wide vector
584 TILE(DATA_TYPE, K0, N0, b);
585
586 // Load tile from the lhs/rhs tensors
587 T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
588 T_LOAD(DATA_TYPE, K0, N0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
589
590 LOOP_UNROLLING(int, m0, 0, 1, M0,
591 {
 // Gather column m0 of the LHS tile into one K0-wide vector
592 VEC_DATA_TYPE(DATA_TYPE, K0)
593 vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);
594
595 LOOP_UNROLLING(int, n0, 0, 1, N0,
596 {
 // Gather column n0 of the RHS tile into one K0-wide vector
597 VEC_DATA_TYPE(DATA_TYPE, K0)
598 vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
599
600 c[m0].s[n0] = arm_matrix_multiply(vec_a, vec_b, c[m0].s[n0]);
601 })
602
603#if RHS_OFFSET != 0
604 // Row Sum of A: Calculate the sum of rows by multiplying A with
605 // a matrix of 1's from Right
606 a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
607#endif // RHS_OFFSET != 0
608 })
609
610#if LHS_OFFSET != 0
611 // Column Sum of B: Calculate the sum of columns by multiplying B
612 // with a matrix of 1's from Left
613 LOOP_UNROLLING(int, n0, 0, 1, N0,
614 {
615 VEC_DATA_TYPE(DATA_TYPE, K0)
616 vec_b = (VEC_DATA_TYPE(DATA_TYPE, K0))(b[0].s[n0], b[1].s[n0], b[2].s[n0], b[3].s[n0]);
617
618 b_sum[0].s[n0] = arm_matrix_multiply(vec_1, vec_b, b_sum[0].s[n0]);
619 })
620#endif // LHS_OFFSET != 0
621
 // Advance along K: both operands move MMUL_K0 rows along y
622 lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
623 rhs_offset_first_element_in_bytes += MMUL_K0 * rhs_stride_y;
624 }
625
626 // Do not write if the coordinates are out of bound
627 // But, read has to happen as arm_matrix_multiply() expects certain number of calls
628 if(dst_x_unclamped >= N || dst_y_unclamped >= M)
629 {
630 return;
631 }
632
633#if RHS_OFFSET != 0 || LHS_OFFSET != 0
 // Offset correction: subtract RHS_OFFSET * a_sum[i] + LHS_OFFSET * b_sum[j] from c[i][j]
634 LOOP_UNROLLING(int, i, 0, 1, M0,
635 {
636 const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
637 LOOP_UNROLLING(int, j, 0, 1, N0,
638 {
639 c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
640 })
641 })
642#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
643
644#ifdef BIAS
 // The bias row is selected with the clamped dst_x
645 perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
646#endif // defined(BIAS)
647
648 // Quantize the tile
649 TILE(DATA_TYPE, M0, N0, cq);
650 T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
651
 // Full-width row stores when the tile fits within N or there is no N0 leftover.
 // In both branches a row is skipped when dst_y + m0 >= M and M0_LEFTOVER != 0.
652 if(dst_x + N0 <= N || N0_LEFTOVER == 0)
653 {
654 LOOP_UNROLLING(int, m0, 0, 1, M0,
655 {
656 if(dst_y + m0 < M || M0_LEFTOVER == 0)
657 {
658 VSTORE(N0)
659 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
660 }
661 })
662 }
663 else
664 {
 // Partial row stores limited to N0_LEFTOVER elements
665 LOOP_UNROLLING(int, m0, 0, 1, M0,
666 {
667 if(dst_y + m0 < M || M0_LEFTOVER == 0)
668 {
669 VSTORE_PARTIAL(N0, N0_LEFTOVER)
670 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
671 }
672 })
673 }
Gunes Bayire87fa662023-09-07 12:20:33 +0100674}
675#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_NT)
676
677#if defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_T)
678/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS transposed, RHS transposed
679 *
680 * Supported block configurations:
Gunes Bayira396da12023-09-20 10:09:43 +0100681 * - M0 = 1, 2, 3, 4, 8, 16
682 * - N0 = 1, 2, 3, 4, 8, 16
683 * - K0 = 4
Gunes Bayire87fa662023-09-07 12:20:33 +0100684 *
685 * Similar to mat_mul_native_quantized_mmul_nt_nt()
686 */
687__kernel void mat_mul_native_quantized_mmul_t_t(
688 TENSOR3D_T(lhs, BUFFER),
689 TENSOR3D_T(rhs, BUFFER),
690#ifdef BIAS
691 TENSOR3D_T(bias, BUFFER),
692#endif // defined(BIAS)
693 TENSOR3D_T(dst, BUFFER))
694{
Gunes Bayira396da12023-09-20 10:09:43 +0100695 const uint x0 = get_global_id(0); // [0, (N / N0) * MMUL_M0)
696 // The upper limit is a simplified version of ((N / N0) / MMUL_N0) * MMUL_BLOCK_SIZE
697 const uint y0 = get_global_id(1); // [0, (M / M0) / MMUL_M0)
698 const uint z = get_global_id(2); // Batch
699
700 // Get section coordinates
 // One section spans MMUL_BLOCK_SIZE consecutive work-items along dimension 0
701 const uint section_x = (x0 / MMUL_BLOCK_SIZE);
702 const uint section_y = y0;
703
704 // Get thread coordinates within an mmul block
705 const uint thread_id = (x0 % MMUL_BLOCK_SIZE);
706 const uint thread_x = thread_id % MMUL_N0;
707 const uint thread_y = (thread_id / MMUL_N0);
708
709 // Calculate dst coordinates
 // The clamped values never exceed (N - N0, M - M0) and are the ones used for addressing.
 // The unclamped values are kept only for the out-of-bound test after the K loop.
710 const uint dst_x_unclamped = thread_x * N0 + section_x * N0 * MMUL_N0;
711 const uint dst_y_unclamped = thread_y * M0 + section_y * M0 * MMUL_M0;
712 const uint dst_x = min(dst_x_unclamped, (uint)(N - N0));
713 const uint dst_y = min(dst_y_unclamped, (uint)(M - M0));
714
715 // Starting LHS coordinates
 // LHS is transposed: the output row maps to x and the K0-deep slice is selected along y by thread_x
716 const uint lhs_x = dst_y;
717 const uint lhs_y = K0 * thread_x;
718
719 // Starting RHS coordinates
 // RHS is transposed: the K0-wide slice is selected along x by thread_y and the output column maps to y
720 const uint rhs_x = K0 * thread_y;
721 const uint rhs_y = dst_x;
722
723 // Compute LHS/RHS/DST matrix address
724 lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
725 rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
726 dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
727
728 // Initialize the accumulators
 // Each accumulator is seeded with the constant term K * LHS_OFFSET * RHS_OFFSET.
 // The terms that depend on the row/column sums are subtracted after the K loop.
729 TILE(int, M0, N0, c);
730 LOOP_UNROLLING(int, i, 0, 1, M0,
731 {
732 c[i].v = K * ((int)LHS_OFFSET) * ((int)RHS_OFFSET);
733 })
734
735 // Calculate row and column sums
 // b_sum is only accumulated when LHS_OFFSET != 0 and a_sum only when RHS_OFFSET != 0; otherwise they stay 0
736 TILE(int, 1, N0, b_sum);
737 b_sum[0].v = 0;
738
739 TILE(int, 1, M0, a_sum);
740 a_sum[0].v = 0;
741
 // All-ones vector used as the second operand when accumulating the sums.
 // The four-element initializer matches the documented K0 = 4.
742 VEC_DATA_TYPE(DATA_TYPE, K0)
743 vec_1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(1, 1, 1, 1);
744
 // Walk the K dimension in steps of MMUL_K0; with a transposed LHS the loop bound is lhs_h
745 for(int k = 0; k < lhs_h; k += MMUL_K0)
746 {
 // LHS tile (transposed layout): K0 rows, each row an M0-wide vector
747 TILE(DATA_TYPE, K0, M0, a);
 // RHS tile (transposed layout): N0 rows, each row a K0-wide vector
748 TILE(DATA_TYPE, N0, K0, b);
749
750 // Load tile from the lhs/rhs tensors
751 T_LOAD(DATA_TYPE, K0, M0, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
752 T_LOAD(DATA_TYPE, N0, K0, BUFFER, rhs, 0, 0, 1, rhs_stride_y, b);
753
754 LOOP_UNROLLING(int, m0, 0, 1, M0,
755 {
 // Gather column m0 of the LHS tile into one K0-wide vector; RHS rows are used as they are
756 VEC_DATA_TYPE(DATA_TYPE, K0)
757 vec_a = (VEC_DATA_TYPE(DATA_TYPE, K0))(a[0].s[m0], a[1].s[m0], a[2].s[m0], a[3].s[m0]);
758
759 LOOP_UNROLLING(int, n0, 0, 1, N0,
760 {
761 c[m0].s[n0] = arm_matrix_multiply(vec_a, b[n0].v, c[m0].s[n0]);
762 })
763#if RHS_OFFSET != 0
764 // Row Sum of A: Calculate the sum of rows by multiplying A with
765 // a matrix of 1's from Right
766 a_sum[0].s[m0] = arm_matrix_multiply(vec_a, vec_1, a_sum[0].s[m0]);
767#endif // RHS_OFFSET != 0
768 })
769
770#if LHS_OFFSET != 0
771 // Column Sum of B: Calculate the sum of columns by multiplying B
772 // with a matrix of 1's from Left
773 LOOP_UNROLLING(int, n0, 0, 1, N0,
774 {
775 b_sum[0].s[n0] = arm_matrix_multiply(vec_1, b[n0].v, b_sum[0].s[n0]);
776 })
777#endif // LHS_OFFSET != 0
778
 // Advance along K: LHS moves MMUL_K0 rows along y, RHS moves MMUL_K0 elements along x
779 lhs_offset_first_element_in_bytes += MMUL_K0 * lhs_stride_y;
780 rhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
781 }
782
783 // Do not write if the coordinates are out of bound
784 // But, read has to happen as arm_matrix_multiply() expects certain number of calls
785 if(dst_x_unclamped >= N || dst_y_unclamped >= M)
786 {
787 return;
788 }
789
790#if RHS_OFFSET != 0 || LHS_OFFSET != 0
 // Offset correction: subtract RHS_OFFSET * a_sum[i] + LHS_OFFSET * b_sum[j] from c[i][j]
791 LOOP_UNROLLING(int, i, 0, 1, M0,
792 {
793 const int A = ((int)RHS_OFFSET) * a_sum[0].s[i];
794 LOOP_UNROLLING(int, j, 0, 1, N0,
795 {
796 c[i].s[j] -= A + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
797 })
798 })
799#endif // RHS_OFFSET != 0 || LHS_OFFSET != 0
800
801#ifdef BIAS
 // The bias row is selected with the clamped dst_x
802 perform_bias_addition(bias_ptr, bias_offset_first_element_in_bytes, c, dst_x);
803#endif // defined(BIAS)
804
805 // Quantize the tile
806 TILE(DATA_TYPE, M0, N0, cq);
807 T_QUANTIZE8_ASYMMETRIC(int, DATA_TYPE, M0, N0, DST_OFFSET, DST_SHIFT, DST_MULTIPLIER, c, cq);
808
 // Full-width row stores when the tile fits within N or there is no N0 leftover.
 // In both branches a row is skipped when dst_y + m0 >= M and M0_LEFTOVER != 0.
809 if(dst_x + N0 <= N || N0_LEFTOVER == 0)
810 {
811 LOOP_UNROLLING(int, m0, 0, 1, M0,
812 {
813 if(dst_y + m0 < M || M0_LEFTOVER == 0)
814 {
815 VSTORE(N0)
816 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
817 }
818 })
819 }
820 else
821 {
 // Partial row stores limited to N0_LEFTOVER elements
822 LOOP_UNROLLING(int, m0, 0, 1, M0,
823 {
824 if(dst_y + m0 < M || M0_LEFTOVER == 0)
825 {
826 VSTORE_PARTIAL(N0, N0_LEFTOVER)
827 (cq[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
828 }
829 })
830 }
Gunes Bayire87fa662023-09-07 12:20:33 +0100831}
832#endif // defined(MAT_MUL_NATIVE_QUANTIZED_MMUL_T_T)