/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "fp_post_ops_act_eltwise_op_act.h"
25#include "gemm_helpers.h"
26#include "repeat.h"
27
/** (EXPERIMENTAL_POST_OPS) gemm_mm_reshaped_only_rhs kernel */
#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
#if defined(P2_ELTWISE_OP) && defined(P2_ELTWISE_ARG1_HEIGHT) && defined(P2_ELTWISE_ARG1_WIDTH)

#define CONCAT(a, b) a##b

#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)       \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
    })
#define ARM_DOT3(a, b, c)           \
    ({                              \
        ARM_DOT2(a, b, c);          \
        c = fma((a.s2), (b.s2), c); \
    })
#define ARM_DOT4(a, b, c)           \
    ({                              \
        ARM_DOT3(a, b, c);          \
        c = fma((a.s3), (b.s3), c); \
    })
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a.lo), (b.lo), c); \
        ARM_DOT4((a.hi), (b.hi), c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a.lo), (b.lo), c); \
        ARM_DOT8((a.hi), (b.hi), c); \
    })

#if N0 == 2
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##4), (c.s4));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##5), (c.s5));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##6), (c.s6));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##7), (c.s7));     \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##4), (c.s4));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##5), (c.s5));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##6), (c.s6));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##7), (c.s7));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##8), (c.s8));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##9), (c.s9));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##A), (c.sA));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##B), (c.sB));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##C), (c.sC));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##D), (c.sD));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##E), (c.sE));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##F), (c.sF));     \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
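
// Illustration only (not compiled): with K0 = 4 and N0 = 2, a call such as
// ARM_DOT_K0XN0(4, a0, b, c0) pastes CONCAT(ARM_DOT, 4) into ARM_DOT4 and expands to
//
//     ARM_DOT4((a0), (b0), (c0.s0)); // c0.s0 accumulates dot(a0.s0123, b0.s0123) as 4 chained fma
//     ARM_DOT4((a0), (b1), (c0.s1)); // c0.s1 accumulates dot(a0.s0123, b1.s0123)
//
// i.e. one K0-deep dot product per accumulator lane of c0.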

/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops:
 * Post op 1: activation (optional)
 * Post op 2: elementwise op
 * Post op 3: activation (optional)
 *
 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
 *
 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_t, with these additions:
 *
 * @param[in] eltwise_operand_ptr      Pointer to the eltwise operand matrix. Supported data type: F16/F32
 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
 * @param[in] eltwise_operand_step_x   eltwise_operand_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
 * @param[in] eltwise_operand_step_y   eltwise_operand_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
 */
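// Example build options (illustrative values only; the real option string is generated by the
// host-side kernel configuration in Compute Library, and the eltwise op value below is an
// assumption):
//
//   -DDATA_TYPE=float -DDATA_TYPE_ACCUMULATOR=float
//   -DM=64 -DN=64 -DK=64 -DM0=4 -DN0=4 -DK0=4 -DH0=4
//   -DPARTIAL_STORE_M0=0 -DPARTIAL_STORE_N0=0
//   -DP2_ELTWISE_OP=ADD -DP2_ELTWISE_ARG1_HEIGHT=64 -DP2_ELTWISE_ARG1_WIDTH=64
//
// The optional activation slots would add e.g. -DP1_ACTIVATION_TYPE=relu
// -DP1_ACTIVATION_A_VAL=0 -DP1_ACTIVATION_B_VAL=0 (and likewise for the P3_* slot).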
__kernel void gemm_mm_reshaped_only_rhs_t_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
                                                                  IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                                                  IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                                  IMAGE_DECLARATION(dst),
                                                                  // Post-Op arguments
                                                                  IMAGE_DECLARATION(eltwise_operand),
                                                                  uint lhs_stride_z,
                                                                  uint rhs_stride_z,
#if defined(BETA)
                                                                  uint bias_stride_z,
#endif //defined(BETA)
                                                                  uint dst_stride_z,
                                                                  uint eltwise_operand_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                                  ,
                                                                  uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                                  ,
                                                                  uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                                 )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)
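
    // Illustration only: with K0 = 4 and H0 = 2, RHS_INTERLEAVE means the K0-wide blocks of
    // H0 adjacent column groups are interleaved, so work-item x starts (x % H0) * RHS_OFFSET_X
    // = (x % H0) * 4 elements in and strides RHS_STEP_X = 4 * 2 = 8 elements between its own
    // blocks; without RHS_INTERLEAVE each column group owns a contiguous region, the offset
    // unit is a whole RHS_BLOCK_SIZE = K0 * N0, and the stride is K0 = 4.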

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    const bool cond_y = y == 0;
    const bool cond_x = ((x + 1) * N0 >= N);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;

    // Compute RHS reshaped matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS reshaped matrix
        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS reshaped matrix
        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(1, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(1, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(1, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(1, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(1, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(1, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(1, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(1, a7, b, c7);
#endif // M0 > 7

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // c = act(c)
    POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
    // c = c + eltwise_operand (mixed-precision, broadcast, boundary aware)
    POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, 1, PARTIAL_STORE_N0, false, cond_x);
    // c = act(c)
    POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);

    // Store output block
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
}

#if defined(OPENCL_IMAGE_SUPPORT)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops. The RHS matrix is stored in an OpenCL image object.
 * Post op 1: activation (optional)
 * Post op 2: elementwise op
 * Post op 3: activation (optional)
 *
 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
 *
 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_t_texture, with these additions:
 *
 * @param[in] eltwise_operand_ptr      Pointer to the eltwise operand matrix. Supported data type: F16/F32
 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
 * @param[in] eltwise_operand_step_x   eltwise_operand_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
 * @param[in] eltwise_operand_step_y   eltwise_operand_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
 */
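// Reminder (from the guards and the code below, not an exhaustive list): the *_texture variants
// are compiled only when -DOPENCL_IMAGE_SUPPORT is defined and additionally rely on
// -DRHS_HEIGHT, since the RHS is sampled from rhs_img with LOAD_TEXTURE2D/READ_IMAGE2D instead
// of being loaded through a buffer pointer.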
__kernel void gemm_mm_reshaped_only_rhs_t_texture_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
                                                                          __read_only image2d_t rhs_img,
#if defined(BETA)
                                                                          IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                                          IMAGE_DECLARATION(dst),
                                                                          // Post-Op arguments
                                                                          IMAGE_DECLARATION(eltwise_operand),
                                                                          uint lhs_stride_z,
                                                                          uint rhs_stride_z,
#if defined(BETA)
                                                                          uint bias_stride_z,
#endif //defined(BETA)
                                                                          uint dst_stride_z,
                                                                          uint eltwise_operand_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                                          ,
                                                                          uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                                          ,
                                                                          uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                                         )
{
    // Pixel unit
#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)

#define LEFTOVER_K (K % K0)

    // Block size
#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (PIXEL_UNIT)
#define RHS_STEP_X (PIXEL_UNIT * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X PIXEL_UNIT
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    const bool cond_y = y == 0;
    const bool cond_x = ((x + 1) * N0 >= N);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
#else // defined(MATRIX_B_DEPTH)
    const uint z_rhs = get_global_id(2);
#endif // defined(MATRIX_B_DEPTH)

    // Compute RHS matrix coordinates
    uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
    const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix stored in a cl_image
        REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
        LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0 * sizeof(DATA_TYPE);
        x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
    }

#if LEFTOVER_K != 0
    // Note: We cannot read out-of-bound elements from the RHS matrix because
    // the RHS width is always a multiple of K0. This is not true for the LHS matrix.

    union UNION_VEC_TYPE
    {
        DATA_TYPE s[K0];
        VEC_DATA_TYPE(DATA_TYPE, K0)
        v;
    };
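    // The union gives per-element access (.s[k]) for the scalar leftover loads below while
    // keeping a vector view (.v), so the zero-padded tail block can still go through the
    // full-width ARM_DOT_K0XN0 path afterwards.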

    union UNION_VEC_TYPE a0 = {.v = 0 };
#if M0 > 1
    union UNION_VEC_TYPE a1 = {.v = 0 };
#endif // M0 > 1
#if M0 > 2
    union UNION_VEC_TYPE a2 = {.v = 0 };
#endif // M0 > 2
#if M0 > 3
    union UNION_VEC_TYPE a3 = {.v = 0 };
#endif // M0 > 3
#if M0 > 4
    union UNION_VEC_TYPE a4 = {.v = 0 };
#endif // M0 > 4
#if M0 > 5
    union UNION_VEC_TYPE a5 = {.v = 0 };
#endif // M0 > 5
#if M0 > 6
    union UNION_VEC_TYPE a6 = {.v = 0 };
#endif // M0 > 6
#if M0 > 7
    union UNION_VEC_TYPE a7 = {.v = 0 };
#endif // M0 > 7

    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);

    // Load from RHS matrix
    LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);

    // Load from LHS matrix
    for(int k = 0; k < LEFTOVER_K; ++k)
    {
        a0.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0);
#if M0 > 1
        a1.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1);
#endif // M0 > 1
#if M0 > 2
        a2.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2);
#endif // M0 > 2
#if M0 > 3
        a3.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3);
#endif // M0 > 3
#if M0 > 4
        a4.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4);
#endif // M0 > 4
#if M0 > 5
        a5.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5);
#endif // M0 > 5
#if M0 > 6
        a6.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6);
#endif // M0 > 6
#if M0 > 7
        a7.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7);
#endif // M0 > 7

        lhs_offset += sizeof(DATA_TYPE);
    }

    // Accumulate
    ARM_DOT_K0XN0(K0, a0.v, b, c0);
#if M0 > 1
    ARM_DOT_K0XN0(K0, a1.v, b, c1);
#endif // M0 > 1
#if M0 > 2
    ARM_DOT_K0XN0(K0, a2.v, b, c2);
#endif // M0 > 2
#if M0 > 3
    ARM_DOT_K0XN0(K0, a3.v, b, c3);
#endif // M0 > 3
#if M0 > 4
    ARM_DOT_K0XN0(K0, a4.v, b, c4);
#endif // M0 > 4
#if M0 > 5
    ARM_DOT_K0XN0(K0, a5.v, b, c5);
#endif // M0 > 5
#if M0 > 6
    ARM_DOT_K0XN0(K0, a6.v, b, c6);
#endif // M0 > 6
#if M0 > 7
    ARM_DOT_K0XN0(K0, a7.v, b, c7);
#endif // M0 > 7

#endif // LEFTOVER_K != 0

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout(M0-1)=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // c = act(c)
    POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
    // c = c + eltwise_operand (mixed-precision, broadcast, boundary aware)
    POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, 1, PARTIAL_STORE_N0, false, cond_x);
    // c = act(c)
    POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);

    // Store output block
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef LEFTOVER_K
#undef PIXEL_UNIT
}
#endif // defined(OPENCL_IMAGE_SUPPORT)
730
731#define VFMA(a, b, c) \
732 ({ \
733 c = fma(a, b, c); \
734 })
735
736#if M0 == 1
737#define VFMA_M0xN0(i, a, b, c) \
738 ({ \
739 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
740 })
741#elif M0 == 2 // M0 == 2
742#define VFMA_M0xN0(i, a, b, c) \
743 ({ \
744 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
745 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
746 })
747#elif M0 == 3 // M0 == 3
748#define VFMA_M0xN0(i, a, b, c) \
749 ({ \
750 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
751 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
752 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
753 })
754#elif M0 == 4 // M0 == 4
755#define VFMA_M0xN0(i, a, b, c) \
756 ({ \
757 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
758 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
759 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
760 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
761 })
762#elif M0 == 5 // M0 == 5
763#define VFMA_M0xN0(i, a, b, c) \
764 ({ \
765 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
766 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
767 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
768 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
769 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
770 })
771#elif M0 == 6 // M0 == 6
772#define VFMA_M0xN0(i, a, b, c) \
773 ({ \
774 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
775 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
776 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
777 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
778 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
779 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
780 })
781#elif M0 == 7 // M0 == 7
782#define VFMA_M0xN0(i, a, b, c) \
783 ({ \
784 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
785 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
786 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
787 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
788 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
789 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
790 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
791 })
792#elif M0 == 8 // M0 == 8
793#define VFMA_M0xN0(i, a, b, c) \
794 ({ \
795 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
796 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
797 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
798 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
799 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
800 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
801 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
802 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
803 })
804#else // M0 not supported
805#error "M0 not supported"
806#endif // M0 not supported
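
// Illustration only (not compiled): with M0 = 2, VFMA_M0xN0(i, a, b0, c) expands to
//
//     VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a0).s<i>), b0, (c0)); // c0 += a0.s<i> * b0
//     VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a1).s<i>), b0, (c1)); // c1 += a1.s<i> * b0
//
// i.e. lane i of each LHS row is broadcast across the N0-wide RHS row b0 and fused
// multiply-added into that row's accumulator.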

/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops:
 * Post op 1: activation (optional)
 * Post op 2: elementwise op
 * Post op 3: activation (optional)
 *
 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
 *
 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_nt, with these additions:
 *
 * @param[in] eltwise_operand_ptr      Pointer to the eltwise operand matrix. Supported data type: F16/F32
 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
 * @param[in] eltwise_operand_step_x   eltwise_operand_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
 * @param[in] eltwise_operand_step_y   eltwise_operand_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
 */
__kernel void gemm_mm_reshaped_only_rhs_nt_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
                                                                   IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                                                   IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                                   IMAGE_DECLARATION(dst),
                                                                   // Post-Op arguments
                                                                   IMAGE_DECLARATION(eltwise_operand),
                                                                   uint lhs_stride_z,
                                                                   uint rhs_stride_z,
#if defined(BETA)
                                                                   uint bias_stride_z,
#endif //defined(BETA)
                                                                   uint dst_stride_z,
                                                                   uint eltwise_operand_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                                   ,
                                                                   uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                                   ,
                                                                   uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                                  )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)
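
    // Note: unlike the transposed (t) variant above, the non-transposed reshape stores N0-wide
    // rows along K, so RHS_OFFSET_X/RHS_STEP_X are expressed in units of N0 rather than K0;
    // the interleave logic is otherwise the same.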

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    const bool cond_y = y == 0;
    const bool cond_x = ((x + 1) * N0 >= N);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;

    // Compute RHS reshaped matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero15=0;

#if defined(REINTERPRET_INPUT_AS_3D)

    // The plane (zin) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);

        VEC_DATA_TYPE(DATA_TYPE, N0)
        b0;

        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(0, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(1, a, b0, c);
#if K0 > 2
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(2, a, b0, c);
#endif // K0 > 2
#if K0 > 3
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(3, a, b0, c);
#endif // K0 > 3
#if K0 > 4
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(4, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(5, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(6, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(7, a, b0, c);
#endif // K0 > 4
#if K0 > 8
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(8, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(9, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(A, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(B, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(C, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(D, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(E, a, b0, c);
        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(F, a, b0, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix (only lane .s0 of each 2-wide vector is consumed by VFMA_M0xN0(0, ...))
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        VEC_DATA_TYPE(DATA_TYPE, N0)
        b0;

        b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VFMA_M0xN0(0, a, b0, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // c = act(c)
    POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
    // c = c + eltwise_operand (mixed-precision, broadcast, boundary aware)
    POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, 1, PARTIAL_STORE_N0, false, cond_x);
    // c = act(c)
    POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);

    // Store output block
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
}

#if defined(OPENCL_IMAGE_SUPPORT)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops. The RHS matrix is stored in an OpenCL image object.
 * Post op 1: activation (optional)
 * Post op 2: elementwise op
 * Post op 3: activation (optional)
 *
 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
 *
 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_nt_texture, with these additions:
 *
 * @param[in] eltwise_operand_ptr      Pointer to the eltwise operand matrix. Supported data type: F16/F32
 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
 * @param[in] eltwise_operand_step_x   eltwise_operand_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
 * @param[in] eltwise_operand_step_y   eltwise_operand_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
 */
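// In this non-transposed texture variant PIXEL_UNIT below is derived from N0 (each image read
// returns one N0-wide row of the RHS) rather than from K0 as in the transposed variant above;
// the same -DOPENCL_IMAGE_SUPPORT / -DRHS_HEIGHT requirements apply.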
__kernel void gemm_mm_reshaped_only_rhs_nt_texture_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
                                                                           __read_only image2d_t rhs_img,
#if defined(BETA)
                                                                           IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                                           IMAGE_DECLARATION(dst),
                                                                           // Post-Op arguments
                                                                           IMAGE_DECLARATION(eltwise_operand),
                                                                           uint lhs_stride_z,
                                                                           uint rhs_stride_z,
#if defined(BETA)
                                                                           uint bias_stride_z,
#endif //defined(BETA)
                                                                           uint dst_stride_z,
                                                                           uint eltwise_operand_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                                           ,
                                                                           uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                                           ,
                                                                           uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                                          )
{
    // Pixel unit
#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)

    // Block size
#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (PIXEL_UNIT)
#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
// RHS_STEP_LOOP is referenced by the x_rhs increment in the compute loop; define it here so the
// kernel does not rely on the (never #undef'd) definition leaking from the kernels above
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (PIXEL_UNIT)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    const bool cond_y = y == 0;
    const bool cond_x = ((x + 1) * N0 >= N);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * (uint)lhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    const uint z_rhs = (z % MATRIX_B_DEPTH);
#else // defined(MATRIX_B_DEPTH)
    const uint z_rhs = z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute RHS matrix coordinates
    uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
    const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)

    // The plane (zin) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);

        VEC_DATA_TYPE(DATA_TYPE, N0)
        b0;

        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(0, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(1, a, b0, c);
#if K0 > 2
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(2, a, b0, c);
#endif // K0 > 2
#if K0 > 3
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(3, a, b0, c);
#endif // K0 > 3
#if K0 > 4
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(4, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(5, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(6, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(7, a, b0, c);
#endif // K0 > 4
#if K0 > 8
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(8, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(9, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(A, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(B, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(C, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(D, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(E, a, b0, c);
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
        VFMA_M0xN0(F, a, b0, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        x_rhs += K0 * RHS_STEP_X * RHS_STEP_LOOP;
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix (only lane .s0 of each 2-wide vector is consumed by VFMA_M0xN0(0, ...))
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        VEC_DATA_TYPE(DATA_TYPE, N0)
        b0;
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));

        VFMA_M0xN0(0, a, b0, c);

        lhs_offset += sizeof(DATA_TYPE);
        x_rhs += RHS_STEP_X;
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated by dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // c = act(c)
    POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
    // c = c + eltwise_operand (mixed-precision, broadcast, boundary aware)
    POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0), DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, 1, PARTIAL_STORE_N0, false, cond_x);
    // c = act(c)
    POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);

    // Store output block
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);

#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
}
#endif // defined(OPENCL_IMAGE_SUPPORT)
#endif // defined(P2_ELTWISE_OP) && defined(P2_ELTWISE_ARG1_HEIGHT) && defined(P2_ELTWISE_ARG1_WIDTH)
#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)