1/*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "fp_post_ops_act_eltwise_op_act.h"
25#include "gemm_helpers.h"
26#include "repeat.h"
27
28/** (EXPERIMENTAL_POST_OPS) gemm_mm_reshaped_only_rhs kernel */
29#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
30#if defined(P2_ELTWISE_OP) && defined(P2_ELTWISE_ARG1_HEIGHT) && defined(P2_ELTWISE_ARG1_WIDTH)
31
32#define CONCAT(a, b) a##b
33
34#define ARM_DOT1(a, b, c) \
35 ({ \
36 c = fma(a, b, c); \
37 })
38#define ARM_DOT2(a, b, c) \
39 ({ \
40 c = fma(a.s0, b.s0, c); \
41 c = fma(a.s1, b.s1, c); \
42 })
43#define ARM_DOT3(a, b, c) \
44 ({ \
45 ARM_DOT2(a, b, c); \
46 c = fma((a.s2), (b.s2), c); \
47 })
48#define ARM_DOT4(a, b, c) \
49 ({ \
50 ARM_DOT3(a, b, c); \
51 c = fma((a.s3), (b.s3), c); \
52 })
53#define ARM_DOT8(a, b, c) \
54 ({ \
55 ARM_DOT4((a.lo), (b.lo), c); \
56 ARM_DOT4((a.hi), (b.hi), c); \
57 })
58#define ARM_DOT16(a, b, c) \
59 ({ \
60 ARM_DOT8((a.lo), (b.lo), c); \
61 ARM_DOT8((a.hi), (b.hi), c); \
62 })
63
64#if N0 == 2
65#define ARM_DOT_K0XN0(k0, a, b, c) \
66 ({ \
67 CONCAT(ARM_DOT, k0) \
68 ((a), (b##0), (c.s0)); \
69 CONCAT(ARM_DOT, k0) \
70 ((a), (b##1), (c.s1)); \
71 })
72#elif N0 == 3 // N0 == 3
73#define ARM_DOT_K0XN0(k0, a, b, c) \
74 ({ \
75 CONCAT(ARM_DOT, k0) \
76 ((a), (b##0), (c.s0)); \
77 CONCAT(ARM_DOT, k0) \
78 ((a), (b##1), (c.s1)); \
79 CONCAT(ARM_DOT, k0) \
80 ((a), (b##2), (c.s2)); \
81 })
82#elif N0 == 4 // N0 == 4
83#define ARM_DOT_K0XN0(k0, a, b, c) \
84 ({ \
85 CONCAT(ARM_DOT, k0) \
86 ((a), (b##0), (c.s0)); \
87 CONCAT(ARM_DOT, k0) \
88 ((a), (b##1), (c.s1)); \
89 CONCAT(ARM_DOT, k0) \
90 ((a), (b##2), (c.s2)); \
91 CONCAT(ARM_DOT, k0) \
92 ((a), (b##3), (c.s3)); \
93 })
94#elif N0 == 8 // N0 == 8
95#define ARM_DOT_K0XN0(k0, a, b, c) \
96 ({ \
97 CONCAT(ARM_DOT, k0) \
98 ((a), (b##0), (c.s0)); \
99 CONCAT(ARM_DOT, k0) \
100 ((a), (b##1), (c.s1)); \
101 CONCAT(ARM_DOT, k0) \
102 ((a), (b##2), (c.s2)); \
103 CONCAT(ARM_DOT, k0) \
104 ((a), (b##3), (c.s3)); \
105 CONCAT(ARM_DOT, k0) \
106 ((a), (b##4), (c.s4)); \
107 CONCAT(ARM_DOT, k0) \
108 ((a), (b##5), (c.s5)); \
109 CONCAT(ARM_DOT, k0) \
110 ((a), (b##6), (c.s6)); \
111 CONCAT(ARM_DOT, k0) \
112 ((a), (b##7), (c.s7)); \
113 })
114#elif N0 == 16 // N0 == 16
115#define ARM_DOT_K0XN0(k0, a, b, c) \
116 ({ \
117 CONCAT(ARM_DOT, k0) \
118 ((a), (b##0), (c.s0)); \
119 CONCAT(ARM_DOT, k0) \
120 ((a), (b##1), (c.s1)); \
121 CONCAT(ARM_DOT, k0) \
122 ((a), (b##2), (c.s2)); \
123 CONCAT(ARM_DOT, k0) \
124 ((a), (b##3), (c.s3)); \
125 CONCAT(ARM_DOT, k0) \
126 ((a), (b##4), (c.s4)); \
127 CONCAT(ARM_DOT, k0) \
128 ((a), (b##5), (c.s5)); \
129 CONCAT(ARM_DOT, k0) \
130 ((a), (b##6), (c.s6)); \
131 CONCAT(ARM_DOT, k0) \
132 ((a), (b##7), (c.s7)); \
133 CONCAT(ARM_DOT, k0) \
134 ((a), (b##8), (c.s8)); \
135 CONCAT(ARM_DOT, k0) \
136 ((a), (b##9), (c.s9)); \
137 CONCAT(ARM_DOT, k0) \
138 ((a), (b##A), (c.sA)); \
139 CONCAT(ARM_DOT, k0) \
140 ((a), (b##B), (c.sB)); \
141 CONCAT(ARM_DOT, k0) \
142 ((a), (b##C), (c.sC)); \
143 CONCAT(ARM_DOT, k0) \
144 ((a), (b##D), (c.sD)); \
145 CONCAT(ARM_DOT, k0) \
146 ((a), (b##E), (c.sE)); \
147 CONCAT(ARM_DOT, k0) \
148 ((a), (b##F), (c.sF)); \
149 })
150#else // N0 not supported
151#error "N0 value not supported"
152#endif // N0 conditions
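// Note (illustrative): ARM_DOT_K0XN0(k0, a, b, c) accumulates the dot product of the K0-wide LHS
// fragment 'a' with each of the N0 reshaped RHS vectors b0..b(N0-1) into the N0 lanes of the
// accumulator 'c'. Assuming, for example, K0 == 4 and N0 == 2, ARM_DOT_K0XN0(4, a, b, c) expands to a
// chain of fused multiply-adds equivalent to:
//   c.s0 = fma(a.s3, b0.s3, fma(a.s2, b0.s2, fma(a.s1, b0.s1, fma(a.s0, b0.s0, c.s0))));
//   c.s1 = fma(a.s3, b1.s3, fma(a.s2, b1.s2, fma(a.s1, b1.s1, fma(a.s0, b1.s0, c.s1))));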
153
154/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops:
155 * Post op 1: activation (optional)
156 * Post op 2: elementwise op
157 * Post op 3: activation (optional)
158 *
159 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
160 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
161 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
162 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
163 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
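 * @note Example (hypothetical) combination of the above build options, shown only for illustration; the actual values are chosen by the host code at compile time:
 *       -DDATA_TYPE=float -DM0=4 -DN0=4 -DK0=4 -DH0=4 -DM=64 -DN=64 -DK=128 -DP2_ELTWISE_OP=ADD -DP2_ELTWISE_ARG1_HEIGHT=64 -DP2_ELTWISE_ARG1_WIDTH=64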
164 *
165 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_t, with these additions:
166 *
167 * @param[in] eltwise_operand_ptr Pointer to the eltwise operand matrix. Supported data type: F16/F32
168 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
169 * @param[in] eltwise_operand_step_x eltwise_operand_stride_x * number of elements along X processed per workitem(in bytes)
170 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
171 * @param[in] eltwise_operand_step_y eltwise_operand_stride_y * number of elements along Y processed per workitem(in bytes)
172 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
173 */
174__kernel void gemm_mm_reshaped_only_rhs_t_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
175 IMAGE_DECLARATION(rhs),
176#if defined(BETA)
177 IMAGE_DECLARATION(bias),
178#endif // defined(BETA)
179 IMAGE_DECLARATION(dst),
180 // Post-Op arguments
181 IMAGE_DECLARATION(eltwise_operand),
182 uint lhs_stride_z,
183 uint rhs_stride_z,
184#if defined(BETA)
185 uint bias_stride_z,
186#endif //defined(BETA)
187 uint dst_stride_z,
188 uint eltwise_operand_stride_z
189#if defined(REINTERPRET_INPUT_AS_3D)
190 ,
191 uint lhs_cross_plane_pad
192#endif // REINTERPRET_INPUT_AS_3D
193#if defined(REINTERPRET_OUTPUT_AS_3D)
194 ,
195 uint dst_cross_plane_pad
196#endif // REINTERPRET_OUTPUT_AS_3D
197 )
198{
199 // Block size
200#define RHS_BLOCK_SIZE ((K0) * (N0))
201
202 // RHS offset and step X
203#if defined(RHS_INTERLEAVE)
204#define RHS_OFFSET_X (K0)
205#define RHS_STEP_X ((K0) * (H0))
206#define RHS_STEP_LOOP (1)
207#else // defined(RHS_INTERLEAVE)
208#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
209#define RHS_STEP_X (K0)
210#define RHS_STEP_LOOP (H0)
211#endif // defined(RHS_INTERLEAVE)
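    // RHS_OFFSET_X is the element offset between the starting positions of two neighbouring work-items'
    // blocks within one row of the reshaped RHS matrix, while RHS_STEP_X is the element stride between
    // two consecutive RHS vectors a work-item loads for one block; both depend on whether the reshaped
    // RHS is stored interleaved (RHS_INTERLEAVE).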
212
213 uint x = get_global_id(0);
214 uint y = get_global_id(1);
215 uint z = get_global_id(2);
216
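    // When DUMMY_WORK_ITEMS is defined, the global work size may have been rounded up beyond the real
    // output, so work-items whose block starts outside the MxN output return immediately without
    // touching memory.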
217#if defined(DUMMY_WORK_ITEMS)
218 if((x * N0 >= N) || (y * M0 >= M))
219 {
220 return;
221 }
222#endif // defined(DUMMY_WORK_ITEMS)
223
224 // Compute LHS matrix address
225 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
226
227 // Compute RHS reshaped matrix address
228 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
229
230#if defined(MATRIX_B_DEPTH)
231 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
232 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
233#else // defined(MATRIX_B_DEPTH)
234 rhs_offset += z * rhs_stride_z;
235#endif // defined(MATRIX_B_DEPTH)
236
237 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
238 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
239
240#if defined(REINTERPRET_INPUT_AS_3D)
241 // The plane (zlhs) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
242 CALCULATE_Z_OFFSET(M0, uint, zlhs, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
243
244 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
245 // multiply lhs_stride_z by DEPTH_GEMM3D
246 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
247
248#else // defined(REINTERPRET_INPUT_AS_3D)
249
250 // Add offset for batched GEMM
251 lhs_offset += z * lhs_stride_z;
252
253#endif // defined(REINTERPRET_INPUT_AS_3D)
254
255 // Initialize the accumulators
256 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
257
258 int i = 0;
259 for(; i <= (K - K0); i += K0)
260 {
261 // Supported cases (M0, K0):
262 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
263 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
264 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
265 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
266 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
267 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
268 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
269 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
270 // Load values from LHS matrix
271 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
272
273 // Load values from RHS reshaped matrix
274 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
275
276 // Accumulate
277 ARM_DOT_K0XN0(K0, a0, b, c0);
278#if M0 > 1
279 ARM_DOT_K0XN0(K0, a1, b, c1);
280#endif // M0 > 1
281#if M0 > 2
282 ARM_DOT_K0XN0(K0, a2, b, c2);
283#endif // M0 > 2
284#if M0 > 3
285 ARM_DOT_K0XN0(K0, a3, b, c3);
286#endif // M0 > 3
287#if M0 > 4
288 ARM_DOT_K0XN0(K0, a4, b, c4);
289#endif // M0 > 4
290#if M0 > 5
291 ARM_DOT_K0XN0(K0, a5, b, c5);
292#endif // M0 > 5
293#if M0 > 6
294 ARM_DOT_K0XN0(K0, a6, b, c6);
295#endif // M0 > 6
296#if M0 > 7
297 ARM_DOT_K0XN0(K0, a7, b, c7);
298#endif // M0 > 7
299
300 lhs_offset += K0 * sizeof(DATA_TYPE);
301 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
302 }
303
304 // Left-over accumulations
305 for(; i < K; ++i)
306 {
307 // Load values from LHS matrix
308 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
309
310 // Load values from RHS reshaped matrix
311 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
312
313 // Accumulate
314 ARM_DOT_K0XN0(1, a0, b, c0);
315#if M0 > 1
316 ARM_DOT_K0XN0(1, a1, b, c1);
317#endif // M0 > 1
318#if M0 > 2
319 ARM_DOT_K0XN0(1, a2, b, c2);
320#endif // M0 > 2
321#if M0 > 3
322 ARM_DOT_K0XN0(1, a3, b, c3);
323#endif // M0 > 3
324#if M0 > 4
325 ARM_DOT_K0XN0(1, a4, b, c4);
326#endif // M0 > 4
327#if M0 > 5
328 ARM_DOT_K0XN0(1, a5, b, c5);
329#endif // M0 > 5
330#if M0 > 6
331 ARM_DOT_K0XN0(1, a6, b, c6);
332#endif // M0 > 6
333#if M0 > 7
334 ARM_DOT_K0XN0(1, a7, b, c7);
335#endif // M0 > 7
336
337 lhs_offset += sizeof(DATA_TYPE);
338 rhs_offset += sizeof(DATA_TYPE);
339 }
340
341 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * dst_stride_y);
342
343 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
344
345 // Boundary conditions: detect if current block is at the "bottom" or "right" boundary
346 const bool cond_y = ((y + 1) * M0 >= M);
347 const bool cond_x = ((x + 1) * N0 >= N);
348
349#if defined(REINTERPRET_OUTPUT_AS_3D)
350
351 // The plane (zout) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
352 CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
353
354 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
355 // multiply dst_stride_z by DEPTH_GEMM3D
356 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
357
358#else // defined(REINTERPRET_OUTPUT_AS_3D)
359
360 // Add offset for batched GEMM
361 dst_addr += z * dst_stride_z;
362
363#endif // defined(REINTERPRET_OUTPUT_AS_3D)
364
365 // Multiply by the weight of matrix-matrix product and store the result
366#if defined(ALPHA)
367 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
368#endif // defined(ALPHA)
369
370 // Add beta*bias
371#if defined(BETA)
372#if defined(BROADCAST_BIAS)
373 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
374
375 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
376
377#ifndef UNIT_BETA
378 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
379#endif // UNIT_BETA
380
381 // c = c + bias[broadcasted]
382 ADD_BLOCK_BROADCAST(M0, c, bias0);
383
384#else // defined(BROADCAST_BIAS)
385 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * bias_stride_y) + z * bias_stride_z;
386
387 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
388
389#ifndef UNIT_BETA
390 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
391#endif // UNIT_BETA
392
393 // c = c + bias
394 ADD_BLOCK(M0, c, bias);
395
396#endif // defined(BROADCAST_BIAS)
397#endif // defined(BETA)
398
399 // c = act(c)
400 POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
401 // c = c + eltwise_operand (mixed-precision, broadcast, boundary-aware)
402 POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
403 // c = act(c)
404 POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
405
406 // Store output block
407 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
408
409#undef RHS_BLOCK_SIZE
410#undef RHS_OFFSET_X
411#undef RHS_STEP_X
412}
413
414#if defined(OPENCL_IMAGE_SUPPORT)
415/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops. The RHS matrix is stored in OpenCL image object.
416 * Post op 1: activation (optional)
417 * Post op 2: elementwise op
418 * Post op 3: activation (optional)
419 *
420 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
421 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
422 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
423 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
424 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
425 *
426 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_t_texture, with these additions:
427 *
428 * @param[in] eltwise_operand_ptr Pointer to the eltwise operand matrix. Supported data type: F16/F32
429 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
430 * @param[in] eltwise_operand_step_x eltwise_operand_stride_x * number of elements along X processed per workitem(in bytes)
431 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
432 * @param[in] eltwise_operand_step_y eltwise_operand_stride_y * number of elements along Y processed per workitem(in bytes)
433 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
434 */
435__kernel void gemm_mm_reshaped_only_rhs_t_texture_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
436 __read_only image2d_t rhs_img,
437#if defined(BETA)
438 IMAGE_DECLARATION(bias),
439#endif // defined(BETA)
440 IMAGE_DECLARATION(dst),
441 // Post-Op arguments
442 IMAGE_DECLARATION(eltwise_operand),
443 uint lhs_stride_z,
444 uint rhs_stride_z,
445#if defined(BETA)
446 uint bias_stride_z,
447#endif //defined(BETA)
448 uint dst_stride_z,
449 uint eltwise_operand_stride_z
450#if defined(REINTERPRET_INPUT_AS_3D)
451 ,
452 uint lhs_cross_plane_pad
453#endif // REINTERPRET_INPUT_AS_3D
454#if defined(REINTERPRET_OUTPUT_AS_3D)
455 ,
456 uint dst_cross_plane_pad
457#endif // REINTERPRET_OUTPUT_AS_3D
458 )
459{
460 // Pixel unit
461#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
462
463#define LEFTOVER_K (K % K0)
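    // PIXEL_UNIT is the number of image pixels needed to hold K0 elements of DATA_TYPE (each pixel of
    // the RHS OpenCL image packs several elements), and LEFTOVER_K is the number of K iterations that
    // cannot be processed in full K0-wide steps and are handled after the main loop.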
464
465 // Block size
466#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
467
468 // RHS offset and step X
469#if defined(RHS_INTERLEAVE)
470#define RHS_OFFSET_X (PIXEL_UNIT)
471#define RHS_STEP_X (PIXEL_UNIT * (H0))
472#define RHS_STEP_LOOP (1)
473#else // defined(RHS_INTERLEAVE)
474#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
475#define RHS_STEP_X PIXEL_UNIT
476#define RHS_STEP_LOOP (H0)
477#endif // defined(RHS_INTERLEAVE)
478
479 uint x = get_global_id(0);
480 uint y = get_global_id(1);
481 uint z = get_global_id(2);
482
483#if defined(DUMMY_WORK_ITEMS)
484 if((x * N0 >= N) || (y * M0 >= M))
485 {
486 return;
487 }
488#endif // defined(DUMMY_WORK_ITEMS)
489
490 // Compute LHS matrix address
491 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
492
493#if defined(MATRIX_B_DEPTH)
494 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
495 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
496#else // defined(MATRIX_B_DEPTH)
497 const uint z_rhs = get_global_id(2);
498#endif // defined(MATRIX_B_DEPTH)
499
500 // Compute RHS matrix coordinates
501 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
502 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
503
504 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
505 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
506
507#if defined(REINTERPRET_INPUT_AS_3D)
508 // The plane (zlhs) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
509 CALCULATE_Z_OFFSET(M0, uint, zlhs, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
510
511 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
512 // multiply lhs_stride_z by DEPTH_GEMM3D
513 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
514
515#else // defined(REINTERPRET_INPUT_AS_3D)
516
517 // Add offset for batched GEMM
518 lhs_offset += z * lhs_stride_z;
519
520#endif // defined(REINTERPRET_INPUT_AS_3D)
521
522 // Initialize the accumulators
523 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);
524
525 int i = 0;
526 for(; i <= (K - K0); i += K0)
527 {
528 // Load values from LHS matrix
529 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
530
531 // Load values from RHS matrix stored in a cl_image
532 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
533 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
534
535 // Accumulate
536 ARM_DOT_K0XN0(K0, a0, b, c0);
537#if M0 > 1
538 ARM_DOT_K0XN0(K0, a1, b, c1);
539#endif // M0 > 1
540#if M0 > 2
541 ARM_DOT_K0XN0(K0, a2, b, c2);
542#endif // M0 > 2
543#if M0 > 3
544 ARM_DOT_K0XN0(K0, a3, b, c3);
545#endif // M0 > 3
546#if M0 > 4
547 ARM_DOT_K0XN0(K0, a4, b, c4);
548#endif // M0 > 4
549#if M0 > 5
550 ARM_DOT_K0XN0(K0, a5, b, c5);
551#endif // M0 > 5
552#if M0 > 6
553 ARM_DOT_K0XN0(K0, a6, b, c6);
554#endif // M0 > 6
555#if M0 > 7
556 ARM_DOT_K0XN0(K0, a7, b, c7);
557#endif // M0 > 7
558
559 lhs_offset += K0 * sizeof(DATA_TYPE);
560 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
561 }
562
563#if LEFTOVER_K != 0
564 // Note: We cannot read out-of-bound elements from the RHS matrix because
565 // the RHS width is always a multiple of K0. This is not necessarily true for the LHS matrix
566
567 union UNION_VEC_TYPE
568 {
569 DATA_TYPE s[K0];
570 VEC_DATA_TYPE(DATA_TYPE, K0)
571 v;
572 };
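    // The union provides two views of the same storage: the leftover K elements are written one scalar
    // at a time through .s[k], while the K0-wide vector view .v is what feeds ARM_DOT_K0XN0 below; the
    // untouched tail lanes stay zero, so they contribute nothing to the accumulators.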
573
574 union UNION_VEC_TYPE a0 = {.v = 0 };
575#if M0 > 1
576 union UNION_VEC_TYPE a1 = {.v = 0 };
577#endif // M0 > 1
578#if M0 > 2
579 union UNION_VEC_TYPE a2 = {.v = 0 };
580#endif // M0 > 2
581#if M0 > 3
582 union UNION_VEC_TYPE a3 = {.v = 0 };
583#endif // M0 > 3
584#if M0 > 4
585 union UNION_VEC_TYPE a4 = {.v = 0 };
586#endif // M0 > 4
587#if M0 > 5
588 union UNION_VEC_TYPE a5 = {.v = 0 };
589#endif // M0 > 5
590#if M0 > 6
591 union UNION_VEC_TYPE a6 = {.v = 0 };
592#endif // M0 > 6
593#if M0 > 7
594 union UNION_VEC_TYPE a7 = {.v = 0 };
595#endif // M0 > 7
596
597 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
598
599 // Load from RHS matrix
600 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
601
602 // Load from LHS matrix
603 for(int k = 0; k < LEFTOVER_K; ++k)
604 {
605 a0.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0);
606#if M0 > 1
607 a1.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1);
608#endif // M0 > 1
609#if M0 > 2
610 a2.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2);
611#endif // M0 > 2
612#if M0 > 3
613 a3.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3);
614#endif // M0 > 3
615#if M0 > 4
616 a4.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4);
617#endif // M0 > 4
618#if M0 > 5
619 a5.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5);
620#endif // M0 > 5
621#if M0 > 6
622 a6.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6);
623#endif // M0 > 6
624#if M0 > 7
625 a7.s[k] = *(__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7);
626#endif // M0 > 7
627
628 lhs_offset += sizeof(DATA_TYPE);
629 }
630
631 // Accumulate
632 ARM_DOT_K0XN0(K0, a0.v, b, c0);
633#if M0 > 1
634 ARM_DOT_K0XN0(K0, a1.v, b, c1);
635#endif // M0 > 1
636#if M0 > 2
637 ARM_DOT_K0XN0(K0, a2.v, b, c2);
638#endif // M0 > 2
639#if M0 > 3
640 ARM_DOT_K0XN0(K0, a3.v, b, c3);
641#endif // M0 > 3
642#if M0 > 4
643 ARM_DOT_K0XN0(K0, a4.v, b, c4);
644#endif // M0 > 4
645#if M0 > 5
646 ARM_DOT_K0XN0(K0, a5.v, b, c5);
647#endif // M0 > 5
648#if M0 > 6
649 ARM_DOT_K0XN0(K0, a6.v, b, c6);
650#endif // M0 > 6
651#if M0 > 7
652 ARM_DOT_K0XN0(K0, a7.v, b, c7);
653#endif // M0 > 7
654
655#endif // LEFTOVER_K != 0
656
657 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * dst_stride_y);
658
659 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
660
661 // Boundary conditions: detect if current block is at the "bottom" or "right" boundary
662 const bool cond_y = ((y + 1) * M0 >= M);
663 const bool cond_x = ((x + 1) * N0 >= N);
664
665#if defined(REINTERPRET_OUTPUT_AS_3D)
666
667 // The plane (zout) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
668 CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
669
670 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
671 // multiply dst_stride_z by DEPTH_GEMM3D
672 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
673
674#else // defined(REINTERPRET_OUTPUT_AS_3D)
675
676 // Add offset for batched GEMM
677 dst_addr += z * dst_stride_z;
678
679#endif // defined(REINTERPRET_OUTPUT_AS_3D)
680
681 // Multiply by the weight of matrix-matrix product and store the result
682#if defined(ALPHA)
683 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
684#endif // defined(ALPHA)
685
686 // Add beta*bias
687#if defined(BETA)
688#if defined(BROADCAST_BIAS)
689 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
690
691 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
692
693#ifndef UNIT_BETA
694 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
695#endif // UNIT_BETA
696
697 // c = c + bias[broadcasted]
698 ADD_BLOCK_BROADCAST(M0, c, bias0);
699
700#else // defined(BROADCAST_BIAS)
701 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * bias_stride_y) + z * bias_stride_z;
702
703 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
704
705#ifndef UNIT_BETA
706 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
707#endif // UNIT_BETA
708
709 // c = c + bias
710 ADD_BLOCK(M0, c, bias);
711
712#endif // defined(BROADCAST_BIAS)
713#endif // defined(BETA)
714
715 // c = act(c)
716 POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
717 // c = c + eltwise_operand (mixed-precision, broadcast, boundary-aware)
718 POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
719 // c = act(c)
720 POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
721
722 // Store output block
723 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
724
725#undef RHS_BLOCK_SIZE
726#undef RHS_OFFSET_X
727#undef RHS_STEP_X
728#undef LEFTOVER_K
729#undef PIXEL_UNIT
730}
731#endif // defined(OPENCL_IMAGE_SUPPORT)
732
733#define VFMA(a, b, c) \
734 ({ \
735 c = fma(a, b, c); \
736 })
737
738#if M0 == 1
739#define VFMA_M0xN0(i, a, b, c) \
740 ({ \
741 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
742 })
743#elif M0 == 2 // M0 == 2
744#define VFMA_M0xN0(i, a, b, c) \
745 ({ \
746 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
747 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
748 })
749#elif M0 == 3 // M0 == 3
750#define VFMA_M0xN0(i, a, b, c) \
751 ({ \
752 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
753 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
754 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
755 })
756#elif M0 == 4 // M0 == 4
757#define VFMA_M0xN0(i, a, b, c) \
758 ({ \
759 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
760 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
761 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
762 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
763 })
764#elif M0 == 5 // M0 == 5
765#define VFMA_M0xN0(i, a, b, c) \
766 ({ \
767 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
768 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
769 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
770 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
771 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
772 })
773#elif M0 == 6 // M0 == 6
774#define VFMA_M0xN0(i, a, b, c) \
775 ({ \
776 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
777 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
778 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
779 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
780 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
781 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
782 })
783#elif M0 == 7 // M0 == 7
784#define VFMA_M0xN0(i, a, b, c) \
785 ({ \
786 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
787 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
788 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
789 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
790 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
791 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
792 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
793 })
794#elif M0 == 8 // M0 == 8
795#define VFMA_M0xN0(i, a, b, c) \
796 ({ \
797 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
798 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
799 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
800 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
801 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
802 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
803 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
804 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
805 })
806#else // M0 not supported
807#error "M0 not supported"
808#endif // M0 not supported
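// Note (illustrative): VFMA_M0xN0(i, a, b, c) broadcasts element i of each LHS fragment a0..a(M0-1)
// across an N0-wide vector and accumulates it against the RHS vector 'b' into the matching accumulator
// row. Assuming, for example, M0 == 2, VFMA_M0xN0(3, a, b0, c) expands to:
//   c0 = fma((VEC_DATA_TYPE(DATA_TYPE, N0))(a0.s3), b0, c0);
//   c1 = fma((VEC_DATA_TYPE(DATA_TYPE, N0))(a1.s3), b0, c1);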
809
810/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops:
811 * Post op 1: activation (optional)
812 * Post op 2: elementwise op
813 * Post op 3: activation (optional)
814 *
815 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
816 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
817 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
818 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
819 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
820 *
821 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_nt, with these additions:
822 *
823 * @param[in] eltwise_operand_ptr Pointer to the eltwise operand matrix. Supported data type: F16/F32
824 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
825 * @param[in] eltwise_operand_step_x eltwise_operand_stride_x * number of elements along X processed per workitem(in bytes)
826 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
827 * @param[in] eltwise_operand_step_y eltwise_operand_stride_y * number of elements along Y processed per workitem(in bytes)
828 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
829 */
830__kernel void gemm_mm_reshaped_only_rhs_nt_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
831 IMAGE_DECLARATION(rhs),
832#if defined(BETA)
833 IMAGE_DECLARATION(bias),
834#endif // defined(BETA)
835 IMAGE_DECLARATION(dst),
836 // Post-Op arguments
837 IMAGE_DECLARATION(eltwise_operand),
838 uint lhs_stride_z,
839 uint rhs_stride_z,
840#if defined(BETA)
841 uint bias_stride_z,
842#endif //defined(BETA)
843 uint dst_stride_z,
844 uint eltwise_operand_stride_z
845#if defined(REINTERPRET_INPUT_AS_3D)
846 ,
847 uint lhs_cross_plane_pad
848#endif // REINTERPRET_INPUT_AS_3D
849#if defined(REINTERPRET_OUTPUT_AS_3D)
850 ,
851 uint dst_cross_plane_pad
852#endif // REINTERPRET_OUTPUT_AS_3D
853 )
854{
855 // Block size
856#define RHS_BLOCK_SIZE ((K0) * (N0))
857
858 // RHS offset and step X
859#if defined(RHS_INTERLEAVE)
860#define RHS_OFFSET_X (N0)
861#define RHS_STEP_X ((N0) * (H0))
862#define RHS_STEP_LOOP (1)
863#else // defined(RHS_INTERLEAVE)
864#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
865#define RHS_STEP_X (N0)
866#define RHS_STEP_LOOP (H0)
867#endif // defined(RHS_INTERLEAVE)
868
869 uint x = get_global_id(0);
870 uint y = get_global_id(1);
871 uint z = get_global_id(2);
872
873#if defined(DUMMY_WORK_ITEMS)
874 if((x * N0 >= N) || (y * M0 >= M))
875 {
876 return;
877 }
878#endif // defined(DUMMY_WORK_ITEMS)
879
880 // Compute LHS matrix address
881 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
882
883 // Compute RHS reshaped matrix address
884 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
885
886#if defined(MATRIX_B_DEPTH)
887 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
888 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
889#else // defined(MATRIX_B_DEPTH)
890 rhs_offset += z * rhs_stride_z;
891#endif // defined(MATRIX_B_DEPTH)
892
893 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
894 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero15=0;
895
896#if defined(REINTERPRET_INPUT_AS_3D)
897
898 // The plane (zin) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
899 CALCULATE_Z_OFFSET(M0, uint, zin, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
900
901 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
902 // multiply lhs_stride_z by DEPTH_GEMM3D
903 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
904
905#else // defined(REINTERPRET_INPUT_AS_3D)
906
907 // Add offset for batched GEMM
908 lhs_offset += z * lhs_stride_z;
909
910#endif // defined(REINTERPRET_INPUT_AS_3D)
911
912 // Initialize the accumulators
913 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
914
915 int i = 0;
916 for(; i <= (K - K0); i += K0)
917 {
918 // Supported cases (M0, K0):
919 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
920 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
921 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
922 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
923 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
924 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
925 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
926 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
927 // Load values from LHS matrix
928 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
929
930 VEC_DATA_TYPE(DATA_TYPE, N0)
931 b0;
932
933 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
934 VFMA_M0xN0(0, a, b0, c);
935 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
936 VFMA_M0xN0(1, a, b0, c);
937#if K0 > 2
938 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
939 VFMA_M0xN0(2, a, b0, c);
940#endif // K0 > 2
941#if K0 > 3
942 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
943 VFMA_M0xN0(3, a, b0, c);
944#endif // K0 > 3
945#if K0 > 4
946 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
947 VFMA_M0xN0(4, a, b0, c);
948 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
949 VFMA_M0xN0(5, a, b0, c);
950 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
951 VFMA_M0xN0(6, a, b0, c);
952 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
953 VFMA_M0xN0(7, a, b0, c);
954#endif // K0 > 4
955#if K0 > 8
956 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
957 VFMA_M0xN0(8, a, b0, c);
958 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
959 VFMA_M0xN0(9, a, b0, c);
960 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
961 VFMA_M0xN0(A, a, b0, c);
962 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
963 VFMA_M0xN0(B, a, b0, c);
964 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
965 VFMA_M0xN0(C, a, b0, c);
966 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
967 VFMA_M0xN0(D, a, b0, c);
968 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
969 VFMA_M0xN0(E, a, b0, c);
970 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
971 VFMA_M0xN0(F, a, b0, c);
972#endif // K0 > 8
973
974 lhs_offset += K0 * sizeof(DATA_TYPE);
975 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
976 }
977
978 // Left-over accumulations
979 for(; i < K; ++i)
980 {
981 // Load values from LHS matrix
982 VEC_DATA_TYPE(DATA_TYPE, 2)
983 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
984#if M0 > 1
985 VEC_DATA_TYPE(DATA_TYPE, 2)
986 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
987#endif // M0 > 1
988#if M0 > 2
989 VEC_DATA_TYPE(DATA_TYPE, 2)
990 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
991#endif // M0 > 2
992#if M0 > 3
993 VEC_DATA_TYPE(DATA_TYPE, 2)
994 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
995#endif // M0 > 3
996#if M0 > 4
997 VEC_DATA_TYPE(DATA_TYPE, 2)
998 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
999#endif // M0 > 4
1000#if M0 > 5
1001 VEC_DATA_TYPE(DATA_TYPE, 2)
1002 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1003#endif // M0 > 5
1004#if M0 > 6
1005 VEC_DATA_TYPE(DATA_TYPE, 2)
1006 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1007#endif // M0 > 6
1008#if M0 > 7
1009 VEC_DATA_TYPE(DATA_TYPE, 2)
1010 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1011#endif // M0 > 7
1012
1013 VEC_DATA_TYPE(DATA_TYPE, N0)
1014 b0;
1015
1016 b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1017 VFMA_M0xN0(0, a, b0, c);
1018
1019 lhs_offset += sizeof(DATA_TYPE);
1020 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1021 }
1022
1023 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * dst_stride_y);
1024
1025 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1026
1027 // Boundary conditions: detect if current block is at the "bottom" or "right" boundary
1028 const bool cond_y = ((y + 1) * M0 >= M);
1029 const bool cond_x = ((x + 1) * N0 >= N);
1030
1031#if defined(REINTERPRET_OUTPUT_AS_3D)
1032 // The plane (zout) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
1033 CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1034
1035 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1036 // multiply dst_stride_z by DEPTH_GEMM3D
1037 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1038
1039#else // defined(REINTERPRET_OUTPUT_AS_3D)
1040
1041 // Add offset for batched GEMM
1042 dst_addr += z * dst_stride_z;
1043
1044#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1045
1046 // Multiply by the weight of matrix-matrix product and store the result
1047#if defined(ALPHA)
1048 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
1049#endif // defined(ALPHA)
1050
1051 // Add beta*bias
1052#if defined(BETA)
1053#if defined(BROADCAST_BIAS)
1054 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1055
1056 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
1057
1058#ifndef UNIT_BETA
1059 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1060#endif // UNIT_BETA
1061
1062 // c = c + bias[broadcasted]
1063 ADD_BLOCK_BROADCAST(M0, c, bias0);
1064
1065#else // defined(BROADCAST_BIAS)
1066 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * bias_stride_y) + z * bias_stride_z;
1067
1068 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1069
1070#ifndef UNIT_BETA
1071 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1072#endif // UNIT_BETA
1073
1074 // c = c + bias
1075 ADD_BLOCK(M0, c, bias);
1076
1077#endif // defined(BROADCAST_BIAS)
1078#endif // defined(BETA)
1079
1080 // c = act(c)
1081 POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
1082 // c = c + eltwise_operand (mixed-precision, broadcast, boundary-aware)
1083 POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1084 // c = act(c)
1085 POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
1086
1087 // Store output block
1088 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1089
1090#undef RHS_BLOCK_SIZE
1091#undef RHS_OFFSET_X
1092#undef RHS_STEP_X
1093}
1094
1095#if defined(OPENCL_IMAGE_SUPPORT)
1096/** This OpenCL kernel computes the matrix multiplication between 2 matrices plus 3 post ops. The RHS matrix is stored in OpenCL image object.
1097 * Post op 1: activation (optional)
1098 * Post op 2: elementwise op
1099 * Post op 3: activation (optional)
1100 *
1101 * @note (Optional) -DP1_ACTIVATION_TYPE, -DP1_ACTIVATION_A_VAL, -DP1_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 1
1102 * @note (Required) -DP2_ELTWISE_OP: The (binary) elementwise post op to perform
1103 * @note (Required) -DP2_ELTWISE_ARG1_HEIGHT: The height (Y dimension) of the eltwise operand matrix of the eltwise post op at slot 2
1104 * @note (Required) -DP2_ELTWISE_ARG1_WIDTH: The width (X dimension) of the eltwise operand matrix of the eltwise post op at slot 2
1105 * @note (Optional) -DP3_ACTIVATION_TYPE, -DP3_ACTIVATION_A_VAL, -DP3_ACTIVATION_B_VAL: The activation type, alpha and beta values of the activation post op at slot 3
1106 *
1107 * All parameters are similarly defined in kernel gemm_mm_reshaped_only_rhs_nt_texture, with these additions:
1108 *
1109 * @param[in] eltwise_operand_ptr Pointer to the eltwise operand matrix. Supported data type: F16/F32
1110 * @param[in] eltwise_operand_stride_x Stride of the eltwise operand matrix in X dimension (in bytes)
1111 * @param[in] eltwise_operand_step_x eltwise_operand_stride_x * number of elements along X processed per workitem(in bytes)
1112 * @param[in] eltwise_operand_stride_y Stride of the eltwise operand matrix in Y dimension (in bytes)
1113 * @param[in] eltwise_operand_step_y eltwise_operand_stride_y * number of elements along Y processed per workitem(in bytes)
1114 * @param[in] eltwise_operand_stride_z Stride of the eltwise operand tensor in Z dimension (in bytes)
1115 */
1116__kernel void gemm_mm_reshaped_only_rhs_nt_texture_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
1117 __read_only image2d_t rhs_img,
1118#if defined(BETA)
1119 IMAGE_DECLARATION(bias),
1120#endif // defined(BETA)
1121 IMAGE_DECLARATION(dst),
1122 // Post-Op arguments
1123 IMAGE_DECLARATION(eltwise_operand),
1124 uint lhs_stride_z,
1125 uint rhs_stride_z,
1126#if defined(BETA)
1127 uint bias_stride_z,
1128#endif //defined(BETA)
1129 uint dst_stride_z,
1130 uint eltwise_operand_stride_z
1131#if defined(REINTERPRET_INPUT_AS_3D)
1132 ,
1133 uint lhs_cross_plane_pad
1134#endif // REINTERPRET_INPUT_AS_3D
1135#if defined(REINTERPRET_OUTPUT_AS_3D)
1136 ,
1137 uint dst_cross_plane_pad
1138#endif // REINTERPRET_OUTPUT_AS_3D
1139 )
1140{
1141 // Pixel unit
1142#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
1143
1144 // Block size
1145#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
1146
1147 // RHS offset and step X
1148#if defined(RHS_INTERLEAVE)
1149#define RHS_OFFSET_X (PIXEL_UNIT)
1150#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
1151#else // defined(RHS_INTERLEAVE)
1152#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1153#define RHS_STEP_X (PIXEL_UNIT)
1154#endif // defined(RHS_INTERLEAVE)
1155
1156 uint x = get_global_id(0);
1157 uint y = get_global_id(1);
1158 uint z = get_global_id(2);
1159
1160#if defined(DUMMY_WORK_ITEMS)
1161 if((x * N0 >= N) || (y * M0 >= M))
1162 {
1163 return;
1164 }
1165#endif // defined(DUMMY_WORK_ITEMS)
1166
1167 // Compute LHS matrix address
1168 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1169
1170#if defined(MATRIX_B_DEPTH)
1171 // Do not slide matrix B if matrix B has 3 dimensions and matrix A has more than 3
1172 const uint z_rhs = (z % MATRIX_B_DEPTH);
1173#else // defined(MATRIX_B_DEPTH)
1174 const uint z_rhs = z;
1175#endif // defined(MATRIX_B_DEPTH)
1176
1177 // Compute RHS matrix coordinates
1178 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
1179 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
1180
1181 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);
1182 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
1183
1184#if defined(REINTERPRET_INPUT_AS_3D)
1185
1186 // The plane (zin) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
1187 CALCULATE_Z_OFFSET(M0, uint, zin, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
1188
1189 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1190 // multiply lhs_stride_z by DEPTH_GEMM3D
1191 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1192
1193#else // defined(REINTERPRET_INPUT_AS_3D)
1194
1195 // Add offset for batched GEMM
1196 lhs_offset += z * lhs_stride_z;
1197
1198#endif // defined(REINTERPRET_INPUT_AS_3D)
1199
1200 // Initialize the accumulators
1201 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);
1202
1203 int i = 0;
1204 for(; i <= (K - K0); i += K0)
1205 {
1206 // Load values from LHS matrix
1207 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
1208
1209 VEC_DATA_TYPE(DATA_TYPE, N0)
1210 b0;
1211
1212 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1213 VFMA_M0xN0(0, a, b0, c);
1214 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
1215 VFMA_M0xN0(1, a, b0, c);
1216#if K0 > 2
1217 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
1218 VFMA_M0xN0(2, a, b0, c);
1219#endif // K0 > 2
1220#if K0 > 3
1221 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
1222 VFMA_M0xN0(3, a, b0, c);
1223#endif // K0 > 3
1224#if K0 > 4
1225 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
1226 VFMA_M0xN0(4, a, b0, c);
1227 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
1228 VFMA_M0xN0(5, a, b0, c);
1229 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
1230 VFMA_M0xN0(6, a, b0, c);
1231 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
1232 VFMA_M0xN0(7, a, b0, c);
1233#endif // K0 > 4
1234#if K0 > 8
1235 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
1236 VFMA_M0xN0(8, a, b0, c);
1237 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
1238 VFMA_M0xN0(9, a, b0, c);
1239 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
1240 VFMA_M0xN0(A, a, b0, c);
1241 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
1242 VFMA_M0xN0(B, a, b0, c);
1243 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
1244 VFMA_M0xN0(C, a, b0, c);
1245 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
1246 VFMA_M0xN0(D, a, b0, c);
1247 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
1248 VFMA_M0xN0(E, a, b0, c);
1249 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
1250 VFMA_M0xN0(F, a, b0, c);
1251#endif // K0 > 8
1252
1253 lhs_offset += K0 * sizeof(DATA_TYPE);
1254 x_rhs += K0 * RHS_STEP_X * RHS_STEP_LOOP;
1255 }
1256
1257 // Left-over accumulations
1258 for(; i < K; ++i)
1259 {
1260 // Load values from LHS matrix
1261 VEC_DATA_TYPE(DATA_TYPE, 2)
1262 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1263#if M0 > 1
1264 VEC_DATA_TYPE(DATA_TYPE, 2)
1265 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1266#endif // M0 > 1
1267#if M0 > 2
1268 VEC_DATA_TYPE(DATA_TYPE, 2)
1269 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1270#endif // M0 > 2
1271#if M0 > 3
1272 VEC_DATA_TYPE(DATA_TYPE, 2)
1273 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1274#endif // M0 > 3
1275#if M0 > 4
1276 VEC_DATA_TYPE(DATA_TYPE, 2)
1277 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1278#endif // M0 > 4
1279#if M0 > 5
1280 VEC_DATA_TYPE(DATA_TYPE, 2)
1281 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1282#endif // M0 > 5
1283#if M0 > 6
1284 VEC_DATA_TYPE(DATA_TYPE, 2)
1285 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1286#endif // M0 > 6
1287#if M0 > 7
1288 VEC_DATA_TYPE(DATA_TYPE, 2)
1289 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1290#endif // M0 > 7
1291
1292 VEC_DATA_TYPE(DATA_TYPE, N0)
1293 b0;
1294 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
1295
1296 VFMA_M0xN0(0, a, b0, c);
1297
1298 lhs_offset += sizeof(DATA_TYPE);
1299 x_rhs += RHS_STEP_X;
1300 }
1301
1302 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * dst_stride_y);
1303
1304 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1305
1306 // Boundary conditions: detect if current block is at the "bottom" or "right" boundary
1307 const bool cond_y = ((y + 1) * M0 >= M);
1308 const bool cond_x = ((x + 1) * N0 >= N);
1309
1310#if defined(REINTERPRET_OUTPUT_AS_3D)
1311 // The plane (zout) is calculated by dividing the row index (y * M0) by HEIGHT_GEMM3D
1312 CALCULATE_Z_OFFSET(M0, uint, zout, y * M0, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
1313
1314 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1315 // multiply dst_stride_z by DEPTH_GEMM3D
1316 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1317
1318#else // defined(REINTERPRET_OUTPUT_AS_3D)
1319
1320 // Add offset for batched GEMM
1321 dst_addr += z * dst_stride_z;
1322
1323#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1324
1325 // Multiply by the weight of matrix-matrix product and store the result
1326#if defined(ALPHA)
1327 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
1328#endif // defined(ALPHA)
1329
1330 // Add beta*bias
1331#if defined(BETA)
1332#if defined(BROADCAST_BIAS)
1333 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1334
1335 LOAD_BLOCK_BOUNDARY_AWARE(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, 1, PARTIAL_STORE_N0, false, cond_x);
1336
1337#ifndef UNIT_BETA
1338 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1339#endif // UNIT_BETA
1340
1341 // c = c + bias[broadcasted]
1342 ADD_BLOCK_BROADCAST(M0, c, bias0);
1343
1344#else // defined(BROADCAST_BIAS)
1345 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * M0 * bias_stride_y) + z * bias_stride_z;
1346
1347 LOAD_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1348
1349#ifndef UNIT_BETA
1350 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1351#endif // UNIT_BETA
1352
1353 // c = c + bias
1354 ADD_BLOCK(M0, c, bias);
1355
1356#endif // defined(BROADCAST_BIAS)
1357#endif // defined(BETA)
1358
1359 // c = act(c)
1360 POST_OP1_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
1361 // c = c + eltwise_operand (mixed-precision, broadcast, boundary-aware)
1362 POST_OP2_ELTWISE_OP(P2_ELTWISE_OP, M0, N0, c, eltwise_operand, DATA_TYPE, DATA_TYPE_ACCUMULATOR, zero, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1363 // c = act(c)
1364 POST_OP3_ACTIVATION_OPTIONAL(M0, DATA_TYPE, DATA_TYPE_ACCUMULATOR, N0, c);
1365
1366 // Store output block
1367 STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
1368
1369#undef RHS_BLOCK_SIZE
1370#undef RHS_OFFSET_X
1371#undef RHS_STEP_X
1372}
1373#endif // defined(OPENCL_IMAGE_SUPPORT)
1374#endif // defined(P2_ELTWISE_OP) && defined(P2_ELTWISE_ARG1_HEIGHT) && defined(P2_ELTWISE_ARG1_WIDTH)
1375#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)