blob: 8d638bc6bbc184d3f0eb3c381107f45e554466aa [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Lane-index constants: INC<n> is the uint<n> vector (0, 1, ..., n-1).
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

// Two-step expansion: INC(n) first replaces its argument (e.g. K0) by its value,
// then CONCAT_INC pastes that value onto "INC" to pick the matching constant above.
#define CONCAT_INC(n) INC##n
#define INC(n) CONCAT_INC(n)

// BOUNDARY_CONDITION_X(x, a): x is the block index along X, a is a K0-wide vector
// holding the elements at columns x * K0 + (0 .. K0-1). The macro overwrites a in
// place: select() keeps the lanes whose column index is below SRC_WIDTH and
// replaces the remaining lanes with 0.
// NOTE(review): CONVERT is defined outside this file; it is assumed to turn the
// comparison result into a per-lane mask of the element type - confirm.
#if(SRC_WIDTH % K0)
// SRC_WIDTH is not a multiple of K0: the last block along X is partial and needs masking
#define BOUNDARY_CONDITION_X(x, a)                                                                                \
    ({                                                                                                            \
        a = select(0, a,                                                                                          \
                   CONVERT(((x * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH),    \
                           VEC_DATA_TYPE(DATA_TYPE, K0)));                                                        \
    })
#else // (SRC_WIDTH % K0)
// SRC_WIDTH is a multiple of K0: every block is complete, nothing to mask
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between two of the V0 blocks that share an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block
    // Both expansions are fully parenthesized so they stay a single operand wherever they are substituted.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x = block index along X, y = block index along Y, z = batch index
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: first element of the M0xK0 block (x, y)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y share the output row (y / V0),
    // and (y % V0) selects the position of this block inside that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block from the LHS matrix into a0...a(M0-1),
    // then zero the lanes of every row that lie beyond SRC_WIDTH
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // No extra per-row Z offset is needed on the output side: all zout variables are 0
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
/** Gather column @p i of the block held in the row vectors a0...a(M0-1) and store it as M0 contiguous elements.
 *
 * The row vectors a0...a(M0-1) are not parameters: they must already be declared in the scope where the macro is used.
 * For M0 = 5, 6 and 7 the column is written with a 4-wide store followed by a second store
 * (scalar, 2-wide or 3-wide respectively) placed 4 elements further on.
 *
 * @param[in] output_ptr    Byte pointer to the start of the destination block
 * @param[in] output_step_x Distance, in elements, between two consecutive stored columns
 * @param[in] i             Column index written as a single hexadecimal digit (0-9, A-F). It is token-pasted to form both
 *                          the vector component (.s<i>) and the numeric index (0x<i>), so it must be a literal digit
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                       \
        VSTORE(M0)                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                              \
        VSTORE(M0)                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                     \
        VSTORE(M0)                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                       \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);               \
        DATA_TYPE res1 = a4.s##i;                                                                         \
        VSTORE(4)                                                                                         \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                           \
    ({                                                                                                     \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                            \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
        VSTORE(2)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                           \
    ({                                                                                                     \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                   \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
        VSTORE(3)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                      \
    ({                                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));                 \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
245
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between two of the V0 blocks that share an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive stored columns of the same block
    // Both expansions are fully parenthesized so they stay a single operand wherever they are substituted.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x = block index along X, y = block index along Y, z = batch index
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: first element of the M0xK0 block (x, y)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y share the output row (y / V0),
    // and (y % V0) selects the position of this block inside that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block from the LHS matrix into a0...a(M0-1),
    // then zero the lanes of every row that lie beyond SRC_WIDTH
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per column of the block (K0 columns in total). The last argument is the column
    // index as a hexadecimal digit; TRANSPOSE_COLUMN_AND_STORE reads a0...a(M0-1) from this scope.
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: distance (in elements) between two of the H0 blocks that share an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block
    // Both expansions are fully parenthesized so they stay a single operand wherever they are substituted.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Load row i of the block into a<i>, but only when that row lies inside the source matrix.
    // i is a single hexadecimal digit (1-9, A-F): it is token-pasted to form both the variable
    // name (a<i>) and the numeric row index (0x<i>), so it must be passed as a literal digit.
#define LOAD_RHS_ROW_IF_VALID(i)                                                               \
    do                                                                                         \
    {                                                                                          \
        if(y * (uint)K0 + 0x##i < SRC_HEIGHT)                                                  \
        {                                                                                      \
            a##i = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0x##i * src_stride_y));     \
        }                                                                                      \
    } while(0)

    // Work-item coordinates: x = block index along X, y = block index along Y, z = batch index
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: first element of the K0xN0 block (x, y) in batch z
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive blocks along X share the output row (x / H0),
    // and (x % H0) selects the position of this block inside that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
                                     x / (uint)H0)
                                 * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create the K0 row vectors a0, a1, ... a(K0-1), each of type VEC_DATA_TYPE(DATA_TYPE, N0), all set to 0.
    // Rows that are not loaded below (because they fall beyond SRC_HEIGHT) therefore stay zero.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. The first row of the block is loaded unconditionally.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    LOAD_RHS_ROW_IF_VALID(1);
#endif // K0 > 1
#if K0 > 2
    LOAD_RHS_ROW_IF_VALID(2);
#endif // K0 > 2
#if K0 > 3
    LOAD_RHS_ROW_IF_VALID(3);
#endif // K0 > 3
#if K0 > 4
    LOAD_RHS_ROW_IF_VALID(4);
    LOAD_RHS_ROW_IF_VALID(5);
    LOAD_RHS_ROW_IF_VALID(6);
    LOAD_RHS_ROW_IF_VALID(7);
#endif // K0 > 4
#if K0 > 8
    LOAD_RHS_ROW_IF_VALID(8);
    LOAD_RHS_ROW_IF_VALID(9);
    LOAD_RHS_ROW_IF_VALID(A);
    LOAD_RHS_ROW_IF_VALID(B);
    LOAD_RHS_ROW_IF_VALID(C);
    LOAD_RHS_ROW_IF_VALID(D);
    LOAD_RHS_ROW_IF_VALID(E);
    LOAD_RHS_ROW_IF_VALID(F);
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    // No extra per-row Z offset is needed on the output side: all zout variables are 0
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef LOAD_RHS_ROW_IF_VALID
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
553
#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * Each work-item handles one K0xN0 block: it loads up to K0 rows of N0 elements, transposes them into N0 vectors of K0 elements
 * and stores those vectors in the destination. Rows whose index is not below SRC_HEIGHT are not read and keep their zero initial value.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Block size (number of elements in one K0xN0 block)
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X (distance in elements between the blocks of consecutive work-items on the same output row)
    // and output step X (distance in elements between consecutive transposed vectors of the same block)
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Load row "row" of the current block into "var", but only if that row lies inside the source tensor.
    // Rows outside the tensor are skipped so that "var" keeps its zero initial value.
#define RESHAPE_RHS_T_LOAD_ROW(row, var)                                                  \
    do                                                                                    \
    {                                                                                     \
        if(y * (uint)K0 + (row) < SRC_HEIGHT)                                             \
        {                                                                                 \
            var = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + (row) * src_stride_y)); \
        }                                                                                 \
    } while(0)

    // Build one transposed vector: component "col" of every loaded row a0 ... a(K0-1)
#if K0 == 2
#define RESHAPE_RHS_T_COLUMN(col) \
    ((VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##col, a1.s##col))
#elif K0 == 3 // K0 == 3
#define RESHAPE_RHS_T_COLUMN(col) \
    ((VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##col, a1.s##col, a2.s##col))
#elif K0 == 4 // K0 == 4
#define RESHAPE_RHS_T_COLUMN(col) \
    ((VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##col, a1.s##col, a2.s##col, a3.s##col))
#elif K0 == 8 // K0 == 8
#define RESHAPE_RHS_T_COLUMN(col)                                                \
    ((VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##col, a1.s##col, a2.s##col, a3.s##col, \
                                    a4.s##col, a5.s##col, a6.s##col, a7.s##col))
#elif K0 == 16 // K0 == 16
#define RESHAPE_RHS_T_COLUMN(col)                                                \
    ((VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##col, a1.s##col, a2.s##col, a3.s##col, \
                                    a4.s##col, a5.s##col, a6.s##col, a7.s##col, \
                                    a8.s##col, a9.s##col, aA.s##col, aB.s##col, \
                                    aC.s##col, aD.s##col, aE.s##col, aF.s##col))
#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // Work-item coordinates
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix. The first row of the block is always loaded; the others are bounds-checked.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    RESHAPE_RHS_T_LOAD_ROW(1, a1);
#if K0 > 2
    RESHAPE_RHS_T_LOAD_ROW(2, a2);
#endif // K0 > 2
#if K0 > 3
    RESHAPE_RHS_T_LOAD_ROW(3, a3);
#endif // K0 > 3
#if K0 > 4
    RESHAPE_RHS_T_LOAD_ROW(4, a4);
    RESHAPE_RHS_T_LOAD_ROW(5, a5);
    RESHAPE_RHS_T_LOAD_ROW(6, a6);
    RESHAPE_RHS_T_LOAD_ROW(7, a7);
#endif // K0 > 4
#if K0 > 8
    RESHAPE_RHS_T_LOAD_ROW(8, a8);
    RESHAPE_RHS_T_LOAD_ROW(9, a9);
    RESHAPE_RHS_T_LOAD_ROW(10, aA);
    RESHAPE_RHS_T_LOAD_ROW(11, aB);
    RESHAPE_RHS_T_LOAD_ROW(12, aC);
    RESHAPE_RHS_T_LOAD_ROW(13, aD);
    RESHAPE_RHS_T_LOAD_ROW(14, aE);
    RESHAPE_RHS_T_LOAD_ROW(15, aF);
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0)    res0=0, res1=0, res2=0,... res(N0-1)=0;

    // K0xN0 -> N0xK0: res<j> collects component j of every loaded row
    res0 = RESHAPE_RHS_T_COLUMN(0);
    res1 = RESHAPE_RHS_T_COLUMN(1);
#if N0 > 2
    res2 = RESHAPE_RHS_T_COLUMN(2);
#endif // N0 > 2
#if N0 > 3
    res3 = RESHAPE_RHS_T_COLUMN(3);
#endif // N0 > 3
#if N0 > 4
    res4 = RESHAPE_RHS_T_COLUMN(4);
    res5 = RESHAPE_RHS_T_COLUMN(5);
    res6 = RESHAPE_RHS_T_COLUMN(6);
    res7 = RESHAPE_RHS_T_COLUMN(7);
#endif // N0 > 4
#if N0 > 8
    res8 = RESHAPE_RHS_T_COLUMN(8);
    res9 = RESHAPE_RHS_T_COLUMN(9);
    resA = RESHAPE_RHS_T_COLUMN(A);
    resB = RESHAPE_RHS_T_COLUMN(B);
    resC = RESHAPE_RHS_T_COLUMN(C);
    resD = RESHAPE_RHS_T_COLUMN(D);
    resE = RESHAPE_RHS_T_COLUMN(E);
    resF = RESHAPE_RHS_T_COLUMN(F);
#endif // N0 > 8

    // ---------------------------Store the output values ------------------------------
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef RESHAPE_RHS_T_COLUMN
#undef RESHAPE_RHS_T_LOAD_ROW
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
886#define CONCAT(a, b) a##b
887
888#define ARM_DOT1(a, b, c) \
889 ({ \
890 c = fma(a, b, c); \
891 })
892#define ARM_DOT2(a, b, c) \
893 ({ \
894 c = fma(a.s0, b.s0, c); \
895 c = fma(a.s1, b.s1, c); \
896 })
897#define ARM_DOT3(a, b, c) \
898 ({ \
899 ARM_DOT2(a, b, c); \
900 c = fma((a.s2), (b.s2), c); \
901 })
902#define ARM_DOT4(a, b, c) \
903 ({ \
904 ARM_DOT3(a, b, c); \
905 c = fma((a.s3), (b.s3), c); \
906 })
907#define ARM_DOT8(a, b, c) \
908 ({ \
909 ARM_DOT4((a.lo), (b.lo), c); \
910 ARM_DOT4((a.hi), (b.hi), c); \
911 })
912#define ARM_DOT16(a, b, c) \
913 ({ \
914 ARM_DOT8((a.lo), (b.lo), c); \
915 ARM_DOT8((a.hi), (b.hi), c); \
916 })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001025 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001034 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1036 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1037 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1038 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1042 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1044 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1058 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
1115 // Compute RHS matrix address
1116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
1161 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
1198 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
/** Fused multiply-add accumulation: c = fma(a, b, c).
 *
 * @param[in]      a First multiplicand
 * @param[in]      b Second multiplicand
 * @param[in, out] c Accumulator. Must be an lvalue: it is read as the addend and overwritten with the result
 */
#define VFMA(a, b, c)     \
    ({                    \
        c = fma(a, b, c); \
    })

/** Load one row of N0 RHS values and accumulate it into the M0 accumulator rows.
 *
 * The macro reads rhs_ptr and rhs_offset from the scope in which it is expanded and loads
 * a VEC_DATA_TYPE(DATA_TYPE, N0) vector from
 *   rhs_ptr + rhs_offset + 0x<i> * RHS_STEP_X * sizeof(DATA_TYPE).
 * For every row r in [0, M0) it then performs
 *   c<r> = fma((VEC_DATA_TYPE(DATA_TYPE, N0))(a<r>.s<i>), b, c<r>)
 * i.e. component <i> of LHS row r is broadcast over N0 lanes and multiplied by the loaded RHS row.
 *
 * One definition is provided per supported M0 (1 to 8); any other value is a build error.
 *
 * @note @p i is only ever token-pasted (0x##i and .s##i), so it must be a single hexadecimal digit (0-9, A-F).
 *
 * @param[in]      i Hexadecimal digit selecting both the RHS row to load and the LHS vector component to use
 * @param[in]      a Basename of the LHS vectors (a0, a1, ..., a(M0-1))
 * @param[in, out] c Basename of the accumulators (c0, c1, ..., c(M0-1))
 */
#if M0 == 1
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
    })
#elif M0 == 2 // M0 == 2
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
    })
#elif M0 == 3 // M0 == 3
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
    })
#elif M0 == 4 // M0 == 4
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
    })
#elif M0 == 5 // M0 == 5
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
    })
#elif M0 == 6 // M0 == 6
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
    })
#elif M0 == 7 // M0 == 7
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                            \
    })
#elif M0 == 8 // M0 == 8
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7));                                            \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported

/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - H0 >= 1
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS matrix (not reshaped). Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
                                           IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                           IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                           IMAGE_DECLARATION(dst),
                                           uint lhs_stride_z,
                                           uint rhs_stride_z,
#if defined(BETA)
                                           uint bias_stride_z,
#endif // defined(BETA)
                                           uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                           ,
                                           uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                           ,
                                           uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                          )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items that fall completely outside the MxN output have nothing to do
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   // uint zin0=0, zin1=0, zin2=0, ... zin7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); // 16 uint variables with basename "zero", all initialized to 0

#if defined(REINTERPRET_INPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators (one N0-wide vector per output row)
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); // VEC_DATA_TYPE(DATA_TYPE, N0) c0=0, c1=0, c2=0, ... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);

        // One RHS row load + M0 fused multiply-adds per element of the K0 block
        LD_RHS_VFMA_M0xN0(0, a, c);
        LD_RHS_VFMA_M0xN0(1, a, c);
#if K0 > 2
        LD_RHS_VFMA_M0xN0(2, a, c);
#endif // K0 > 2
#if K0 > 3
        LD_RHS_VFMA_M0xN0(3, a, c);
#endif // K0 > 3
#if K0 > 4
        LD_RHS_VFMA_M0xN0(4, a, c);
        LD_RHS_VFMA_M0xN0(5, a, c);
        LD_RHS_VFMA_M0xN0(6, a, c);
        LD_RHS_VFMA_M0xN0(7, a, c);
#endif // K0 > 4
#if K0 > 8
        LD_RHS_VFMA_M0xN0(8, a, c);
        LD_RHS_VFMA_M0xN0(9, a, c);
        LD_RHS_VFMA_M0xN0(A, a, c);
        LD_RHS_VFMA_M0xN0(B, a, c);
        LD_RHS_VFMA_M0xN0(C, a, c);
        LD_RHS_VFMA_M0xN0(D, a, c);
        LD_RHS_VFMA_M0xN0(E, a, c);
        LD_RHS_VFMA_M0xN0(F, a, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix.
        // Each scalar is stored in a 2-element vector so that LD_RHS_VFMA_M0xN0 can read it through the .s0 accessor
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        LD_RHS_VFMA_M0xN0(0, a, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); // uint zout0=0, zout1=0, zout2=0, ... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Release every helper macro defined at the top of this kernel
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)

// The definitions below additionally require V0 (the LHS reshape factor) to be passed at compile time
#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)

/** Accumulate the dot product of the first K0 components of two vectors: c += a.s0 * b.s0 + ... (K0 terms).
 *
 * Each term is accumulated with fma, in component order. One definition is provided per supported
 * K0 (2, 3, 4, 8, 16); any other value is a build error.
 *
 * @param[in]      a Vector with at least K0 components
 * @param[in]      b Vector with at least K0 components
 * @param[in, out] c Scalar accumulator. Must be an lvalue
 */
// NOTE(review): an ARM_DOT_K0 with a different parameter list may already have been defined earlier in
// this file; undefining first avoids an incompatible macro redefinition. #undef is a no-op otherwise.
#undef ARM_DOT_K0
#if K0 == 2
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
        c = fma(a.s4, b.s4, c); \
        c = fma(a.s5, b.s5, c); \
        c = fma(a.s6, b.s6, c); \
        c = fma(a.s7, b.s7, c); \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
        c = fma(a.s4, b.s4, c); \
        c = fma(a.s5, b.s5, c); \
        c = fma(a.s6, b.s6, c); \
        c = fma(a.s7, b.s7, c); \
        c = fma(a.s8, b.s8, c); \
        c = fma(a.s9, b.s9, c); \
        c = fma(a.sA, b.sA, c); \
        c = fma(a.sB, b.sB, c); \
        c = fma(a.sC, b.sC, c); \
        c = fma(a.sD, b.sD, c); \
        c = fma(a.sE, b.sE, c); \
        c = fma(a.sF, b.sF, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions

/** Accumulate N0 dot products of one LHS vector against N0 RHS vectors.
 *
 * For every j in [0, N0) (j written as a hexadecimal digit) this expands to
 *   ARM_DOT_K0(a, b<j>, c.s<j>)
 * so component j of @p c accumulates the K0-term dot product between @p a and the vector b<j>.
 * One definition is provided per supported N0 (2, 3, 4, 8, 16); any other value is a build error.
 *
 * @param[in]      a LHS vector
 * @param[in]      b Basename of the RHS vectors (b0, b1, ..., up to bF for N0 = 16)
 * @param[in, out] c Accumulator vector with at least N0 components
 */
// A four-argument ARM_DOT_K0XN0(k0, a, b, c) is used by the kernels above; drop any previous definition
// so that the three-argument form below is not an incompatible macro redefinition.
#undef ARM_DOT_K0XN0
#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions

1795/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1796 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1797 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1798 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001799 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001800 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
1801 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1802 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1803 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001804 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1805 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1806 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001807 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001808 * - N0 = 2, 3, 4, 8, 16
1809 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001810 * - V0 >= 1
1811 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001812 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001813 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001814 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001815 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001816 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1817 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1818 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1819 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1820 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001821 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1822 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1823 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1824 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1825 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1826 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1827 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1828 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1829 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1830 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1831 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1832 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1833 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1834 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1835 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1836 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1837 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1838 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1839 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1840 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1841 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1842 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1843 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1844 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1845 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1846 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1847 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1848 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1849 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1850 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001851 */
1852__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1853 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001854#if defined(BETA)
1855 IMAGE_DECLARATION(bias),
1856#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001857 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001858 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001859 uint lhs_stride_z,
1860 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001861#if defined(BETA)
1862 uint bias_stride_z,
1863#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001864 uint dst_stride_z
1865#if defined(REINTERPRET_OUTPUT_AS_3D)
1866 ,
1867 uint dst_cross_plane_pad
1868#endif // REINTERPRET_OUTPUT_AS_3D
1869 )
1870{
1871 // Block size
1872#define LHS_BLOCK_SIZE ((K0) * (M0))
1873
1874#if defined(LHS_INTERLEAVE)
1875#define LHS_OFFSET_X (K0)
1876#define LHS_STEP_X ((K0) * (V0))
1877#define LHS_STEP_LOOP (1)
1878#else // defined(INTERLEAVE)
1879#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1880#define LHS_STEP_X (K0)
1881#define LHS_STEP_LOOP (V0)
1882#endif // defined(INTERLEAVE)
1883
1884 // Block size
1885#define RHS_BLOCK_SIZE ((K0) * (N0))
1886
1887 // RHS offset and step X
1888#if defined(RHS_INTERLEAVE)
1889#define RHS_OFFSET_X (K0)
1890#define RHS_STEP_X ((K0) * (H0))
1891#define RHS_STEP_LOOP (1)
1892#else // defined(RHS_INTERLEAVE)
1893#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1894#define RHS_STEP_X (K0)
1895#define RHS_STEP_LOOP (H0)
1896#endif // defined(RHS_INTERLEAVE)
1897
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001898#if defined(DUMMY_WORK_ITEMS)
1899 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1900 {
1901 return;
1902 }
1903#endif // defined(DUMMY_WORK_ITEMS)
1904
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001905 // Compute LHS matrix address
1906 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1907 (get_global_id(2) * lhs_stride_z);
1908
1909 // Compute RHS matrix address
1910 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1911
1912#if defined(MATRIX_B_DEPTH)
1913 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1914 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1915#else // defined(MATRIX_B_DEPTH)
1916 rhs_addr += get_global_id(2) * rhs_stride_z;
1917#endif // defined(MATRIX_B_DEPTH)
1918
1919 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001920 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001921
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001922 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1923 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001924
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001925 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001926 {
1927 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001928 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1929 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1930 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1931 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1932 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1933 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1934 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1935 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001936 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001937 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001938
1939 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001940 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001941
1942 // Accumulate
1943 ARM_DOT_K0XN0(a0, b, c0);
1944#if M0 > 1
1945 ARM_DOT_K0XN0(a1, b, c1);
1946#endif // M0 > 1
1947#if M0 > 2
1948 ARM_DOT_K0XN0(a2, b, c2);
1949#endif // M0 > 2
1950#if M0 > 3
1951 ARM_DOT_K0XN0(a3, b, c3);
1952#endif // M0 > 3
1953#if M0 > 4
1954 ARM_DOT_K0XN0(a4, b, c4);
1955#endif // M0 > 4
1956#if M0 > 5
1957 ARM_DOT_K0XN0(a5, b, c5);
1958#endif // M0 > 5
1959#if M0 > 6
1960 ARM_DOT_K0XN0(a6, b, c6);
1961#endif // M0 > 6
1962#if M0 > 7
1963 ARM_DOT_K0XN0(a7, b, c7);
1964#endif // M0 > 7
1965
1966 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1967 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1968 }
1969
1970 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1971
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001972 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001973
1974#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001975
1976 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001977 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001978 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1979 // multiply dst_stride_z by DEPTH_GEMM3D
1980 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1981
1982#else // defined(REINTERPRET_OUTPUT_AS_3D)
1983
1984 // Add offset for batched GEMM
1985 dst_addr += get_global_id(2) * dst_stride_z;
1986
1987#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1988
1989 // Multiply by the weight of matrix-matrix product and store the result
1990#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001991 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001992#endif // defined(ALPHA)
1993
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001994 // Add beta*bias
1995#if defined(BETA)
1996#if defined(BROADCAST_BIAS)
1997 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1998
1999 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2000
2001#ifndef UNIT_BETA
2002 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2003#endif // UNIT_BIAS
2004
2005 // c = c + bias[broadcasted]
2006 ADD_BLOCK_BROADCAST(M0, c, bias0);
2007
2008#else // defined(BROADCAST_BIAS)
2009 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2010 2) * bias_stride_z;
2011
2012 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2013
2014#ifndef UNIT_BETA
2015 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2016#endif // UNIT_BIAS
2017
2018 // c = c + bias
2019 ADD_BLOCK(M0, c, bias);
2020
2021#endif // defined(BROADCAST_BIAS)
2022#endif // defined(BETA)
2023
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002024#if defined(ACTIVATION_TYPE)
2025 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2026#endif // defined(ACTIVATION_TYPE)
2027
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002028 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01002029 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002030
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002031#undef LHS_BLOCK_SIZE
2032#undef LHS_OFFSET_X
2033#undef LHS_STEP_X
2034#undef RHS_BLOCK_SIZE
2035#undef RHS_OFFSET_X
2036#undef RHS_STEP_X
2037}
giuros01b3204e72019-04-01 13:50:22 +01002038
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002039#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2040
giuros01b3204e72019-04-01 13:50:22 +01002041#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2042
2043#define VFMA(a, b, c) \
2044 ({ \
2045 c = fma(a, b, c); \
2046 })
2047
2048#if M0 == 1
2049#define RHS_VFMA_M0xN0(i, a, b, c) \
2050 ({ \
2051 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2052 })
2053#elif M0 == 2 // M0 == 2
2054#define RHS_VFMA_M0xN0(i, a, b, c) \
2055 ({ \
2056 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2057 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2058 })
2059#elif M0 == 3 // M0 == 3
2060#define RHS_VFMA_M0xN0(i, a, b, c) \
2061 ({ \
2062 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2063 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2064 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2065 })
2066#elif M0 == 4 // M0 == 4
2067#define RHS_VFMA_M0xN0(i, a, b, c) \
2068 ({ \
2069 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2070 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2071 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2072 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2073 })
2074#elif M0 == 5 // M0 == 5
2075#define RHS_VFMA_M0xN0(i, a, b, c) \
2076 ({ \
2077 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2078 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2079 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2080 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2081 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2082 })
2083#elif M0 == 6 // M0 == 6
2084#define RHS_VFMA_M0xN0(i, a, b, c) \
2085 ({ \
2086 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2087 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2088 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2089 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2090 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2091 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2092 })
2093#elif M0 == 7 // M0 == 7
2094#define RHS_VFMA_M0xN0(i, a, b, c) \
2095 ({ \
2096 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2097 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2098 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2099 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2100 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2101 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2102 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2103 })
2104#elif M0 == 8 // M0 == 8
2105#define RHS_VFMA_M0xN0(i, a, b, c) \
2106 ({ \
2107 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2108 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2109 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2110 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2111 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2112 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2113 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2114 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
2115 })
2116#else // M0 not supported
2117#error "M0 not supported"
2118#endif // M0 not supported
2119
2120/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2121 * The LHS matrix is NOT reshaped
2122 * The RHS matrix is NOT reshaped
2123 *
2124 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
2126 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
2127 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
2128 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
2129 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
giuros01b3204e72019-04-01 13:50:22 +01002130 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2131 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2132 * - N0 = 2, 3, 4, 8, 16
2133 * - K0 = 2, 3, 4, 8, 16
2134 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002135 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002136 * The activation function is performed after the bias addition
giuros01b3204e72019-04-01 13:50:22 +01002137 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2138 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2139 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2140 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2141 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2142 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2143 *
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002144 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
2145 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
2146 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
2147 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
2148 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2149 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
2150 * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
2151 * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
2152 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
2153 * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
2154 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
2155 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002156 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2157 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2158 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2159 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2160 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2161 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2162 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2163 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2164 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2165 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2166 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2167 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2168 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
2169 * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
2170 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2171 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2172 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2173 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
giuros01b3204e72019-04-01 13:50:22 +01002174 */
2175__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2176 IMAGE_DECLARATION(rhs),
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002177#if defined(BETA)
2178 IMAGE_DECLARATION(bias),
2179#endif // defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002180 IMAGE_DECLARATION(dst),
2181 uint lhs_stride_z,
2182 uint rhs_stride_z,
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002183#if defined(BETA)
2184 uint bias_stride_z,
2185#endif //defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002186 uint dst_stride_z
2187#if defined(REINTERPRET_INPUT_AS_3D)
2188 ,
2189 uint lhs_cross_plane_pad
2190#endif // REINTERPRET_INPUT_AS_3D
2191#if defined(REINTERPRET_OUTPUT_AS_3D)
2192 ,
2193 uint dst_cross_plane_pad
2194#endif // REINTERPRET_OUTPUT_AS_3D
2195 )
2196{
2197 // Block size
2198#define RHS_BLOCK_SIZE ((K0) * (N0))
2199
2200 // RHS offset and step X
2201#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2202
2203 uint x = get_global_id(0);
2204 uint y = get_global_id(1);
2205 uint z = get_global_id(2);
2206
2207#if defined(DUMMY_WORK_ITEMS)
2208 if((x * N0 >= N) || (y * M0 >= M))
2209 {
2210 return;
2211 }
2212#endif // defined(DUMMY_WORK_ITEMS)
2213
2214 // Compute LHS matrix address
2215 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2216
2217 // Compute RHS matrix address
2218 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2219
2220#if defined(MATRIX_B_DEPTH)
2221 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2222 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2223#else // defined(MATRIX_B_DEPTH)
2224 rhs_offset += z * rhs_stride_z;
2225#endif // defined(MATRIX_B_DEPTH)
2226
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002227 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
2228 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
giuros01b3204e72019-04-01 13:50:22 +01002229
2230#if defined(REINTERPRET_INPUT_AS_3D)
2231 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2232 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2233
2234 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2235 // multiply lhs_stride_z by DEPTH_GEMM3D
2236 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2237
2238#else // defined(REINTERPRET_INPUT_AS_3D)
2239
2240 // Add offset for batched GEMM
2241 lhs_offset += z * lhs_stride_z;
2242
2243#endif // defined(REINTERPRET_INPUT_AS_3D)
2244
2245 // Initialize the accumulators
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002246 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
giuros01b3204e72019-04-01 13:50:22 +01002247
2248 int i = 0;
2249 for(; i <= (K - K0); i += K0)
2250 {
2251 // Supported cases (M0, K0):
2252 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2253 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2254 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2255 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2256 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2257 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2258 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2259 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2260 // Load values from LHS matrix
2261 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2262
2263 // Load values from RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002264 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
giuros01b3204e72019-04-01 13:50:22 +01002265
2266 RHS_VFMA_M0xN0(0, a, b0, c);
2267 RHS_VFMA_M0xN0(1, a, b1, c);
2268#if K0 > 2
2269 RHS_VFMA_M0xN0(2, a, b2, c);
2270#endif // K0 > 2
2271#if K0 > 3
2272 RHS_VFMA_M0xN0(3, a, b3, c);
2273#endif // K0 > 3
2274#if K0 > 4
2275 RHS_VFMA_M0xN0(4, a, b4, c);
2276 RHS_VFMA_M0xN0(5, a, b5, c);
2277 RHS_VFMA_M0xN0(6, a, b6, c);
2278 RHS_VFMA_M0xN0(7, a, b7, c);
2279#endif // K0 > 4
2280#if K0 > 8
2281 RHS_VFMA_M0xN0(8, a, b8, c);
2282 RHS_VFMA_M0xN0(9, a, b9, c);
2283 RHS_VFMA_M0xN0(A, a, b10, c);
2284 RHS_VFMA_M0xN0(B, a, b11, c);
2285 RHS_VFMA_M0xN0(C, a, b12, c);
2286 RHS_VFMA_M0xN0(D, a, b13, c);
2287 RHS_VFMA_M0xN0(E, a, b14, c);
2288 RHS_VFMA_M0xN0(F, a, b15, c);
2289#endif // K0 > 8
2290
2291 lhs_offset += K0 * sizeof(DATA_TYPE);
2292 rhs_offset += K0 * rhs_stride_y;
2293 }
2294
2295 // Left-over accumulations
2296 for(; i < K; ++i)
2297 {
2298 // Load values from LHS matrix
2299 VEC_DATA_TYPE(DATA_TYPE, 2)
2300 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2301#if M0 > 1
2302 VEC_DATA_TYPE(DATA_TYPE, 2)
2303 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2304#endif // M0 > 1
2305#if M0 > 2
2306 VEC_DATA_TYPE(DATA_TYPE, 2)
2307 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2308#endif // M0 > 2
2309#if M0 > 3
2310 VEC_DATA_TYPE(DATA_TYPE, 2)
2311 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2312#endif // M0 > 3
2313#if M0 > 4
2314 VEC_DATA_TYPE(DATA_TYPE, 2)
2315 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2316#endif // M0 > 4
2317#if M0 > 5
2318 VEC_DATA_TYPE(DATA_TYPE, 2)
2319 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2320#endif // M0 > 5
2321#if M0 > 6
2322 VEC_DATA_TYPE(DATA_TYPE, 2)
2323 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2324#endif // M0 > 6
2325#if M0 > 7
2326 VEC_DATA_TYPE(DATA_TYPE, 2)
2327 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2328#endif // M0 > 7
2329
2330 VEC_DATA_TYPE(DATA_TYPE, N0)
2331 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2332 RHS_VFMA_M0xN0(0, a, b, c);
2333
2334 lhs_offset += sizeof(DATA_TYPE);
2335 rhs_offset += rhs_stride_y;
2336 }
2337
2338 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2339
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002340 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
giuros01b3204e72019-04-01 13:50:22 +01002341
2342#if defined(REINTERPRET_OUTPUT_AS_3D)
2343 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2344 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2345
2346 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2347 // multiply dst_stride_z by DEPTH_GEMM3D
2348 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2349
2350#else // defined(REINTERPRET_OUTPUT_AS_3D)
2351
2352 // Add offset for batched GEMM
2353 dst_addr += z * dst_stride_z;
2354
2355#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2356
2357 // Multiply by the weight of matrix-matrix product and store the result
giuros01b3204e72019-04-01 13:50:22 +01002358#if defined(ALPHA)
2359 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2360#endif // defined(ALPHA)
2361
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002362 // Add beta*bias
2363#if defined(BETA)
2364#if defined(BROADCAST_BIAS)
2365 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2366
2367 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2368
2369#ifndef UNIT_BETA
2370 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2371#endif // UNIT_BIAS
2372
2373 // c = c + bias[broadcasted]
2374 ADD_BLOCK_BROADCAST(M0, c, bias0);
2375
2376#else // defined(BROADCAST_BIAS)
2377 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2378 2) * bias_stride_z;
2379
2380 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2381
2382#ifndef UNIT_BETA
2383 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2384#endif // UNIT_BIAS
2385
2386 // c = c + bias
2387 ADD_BLOCK(M0, c, bias);
2388
2389#endif // defined(BROADCAST_BIAS)
2390#endif // defined(BETA)
2391
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002392#if defined(ACTIVATION_TYPE)
2393 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2394#endif // defined(ACTIVATION_TYPE)
2395
giuros01b3204e72019-04-01 13:50:22 +01002396 // Store output block
2397 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2398
2399#undef RHS_BLOCK_SIZE
2400#undef RHS_OFFSET_X
2401#undef RHS_STEP_X
2402}
2403#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2404
Gian Marco36a0a462018-01-12 10:21:40 +00002405#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002406/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002407 *
Gian Marco19835e52018-01-30 13:35:54 +00002408 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002409 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2410 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2411 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2412 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002413 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002414 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2415 * The activation function is performed after the bias addition
2416 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002417 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2418 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2419 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2420 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2421 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002422 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2423 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
2427 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002428 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002429 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
2433 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
2435 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2436 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2437 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2438 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2439 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002440 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002441 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002442 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002443 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002444 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002445 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002446 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2447 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002448 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002449 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002450 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002451 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002452__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2453 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002454#if defined(BETA)
2455 IMAGE_DECLARATION(src2),
2456#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002457 IMAGE_DECLARATION(dst),
2458 uint src0_stride_z,
2459 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002460#if defined(BETA)
2461 uint src2_stride_z,
2462#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002463 uint dst_stride_z
2464#if defined(REINTERPRET_OUTPUT_AS_3D)
2465 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002466 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002467#endif // REINTERPRET_OUTPUT_AS_3D
2468 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002469{
Gian Marco36a0a462018-01-12 10:21:40 +00002470 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2471 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002472 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002473
Gian Marco36a0a462018-01-12 10:21:40 +00002474 // Offset
2475 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2476 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002477
Gian Marco36a0a462018-01-12 10:21:40 +00002478 // src_addr_a = address of matrix A
2479 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002480 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2481 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2482
2483#if defined(MATRIX_B_DEPTH)
2484 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2485 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2486#else // defined(MATRIX_B_DEPTH)
2487 src1_addr_in_bytes += z * src1_stride_z;
2488#endif // defined(MATRIX_B_DEPTH)
2489
2490 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2491 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002492
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002493 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002494 __global float *src_end_addr_b = src_addr_b + COLS_B;
2495
2496 src_addr_a += offset_row_a;
2497 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002498
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002499 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002500 float4 c0 = 0.0f;
2501 float4 c1 = 0.0f;
2502 float4 c2 = 0.0f;
2503 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002504
Gian Marco36a0a462018-01-12 10:21:40 +00002505 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002506 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002507 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002508 float4 a0 = vload4(0, src_addr_a);
2509 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002510
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002511 c0 += (float4)a0.s0 * b0;
2512 c1 += (float4)a0.s1 * b0;
2513 c2 += (float4)a0.s2 * b0;
2514 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002515
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002516 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002517 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2518 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002519
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002520 c0 += (float4)a0.s0 * b0;
2521 c1 += (float4)a0.s1 * b0;
2522 c2 += (float4)a0.s2 * b0;
2523 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002524 }
2525
Gian Marco36a0a462018-01-12 10:21:40 +00002526 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002527 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002528 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002529 float4 a0 = vload4(0, src_addr_a);
2530 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002531
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002532 c0 += (float4)a0.s0 * b0;
2533 c1 += (float4)a0.s1 * b0;
2534 c2 += (float4)a0.s2 * b0;
2535 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002536 }
2537
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002538 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002539 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2540
Gian Marcoae2af742018-02-15 12:35:44 +00002541 // Compute dst address
2542 __global uchar *dst_addr = offset(&dst, 0, 0);
2543
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002544 uint4 zout = 0;
2545
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002546#if defined(REINTERPRET_OUTPUT_AS_3D)
2547 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002548 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002549 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002550 // | |
2551 // | plane0 |
2552 // | |
2553 // |__________________|
2554 // |******************|
2555 // | cross_plane_pad |
2556 // |******************|
2557 // | |
2558 // | plane1 |
2559 // | |
2560 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002561
2562 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002563 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2564 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002565
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002566 // Add offset due to the cross plane paddings
2567 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002568
2569 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2570 // multiply dst_stride_z by DEPTH_GEMM3D
2571 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002572#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002573 // Add offset for batched GEMM
2574 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002575#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2576
2577 // Multiply by the weight of matrix-matrix product and store the result
2578#if defined(ALPHA)
2579 SCALE_BLOCK(4, float, c, ALPHA);
2580#endif // defined(ALPHA)
2581
2582 // Add beta*bias
2583#if defined(BETA)
2584 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
2585
2586#if defined(BROADCAST_BIAS)
2587 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
2588
2589 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2590
2591#ifndef UNIT_BETA
2592 SCALE_BLOCK(1, float, bias, BETA);
2593#endif // UNIT_BIAS
2594
2595 // c = c + bias[broadcasted]
2596 ADD_BLOCK_BROADCAST(4, c, bias0);
2597
2598#else // defined(BROADCAST_BIAS)
2599 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
2600 2) * src2_stride_z;
2601
2602 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2603
2604#ifndef UNIT_BETA
2605 SCALE_BLOCK(4, float, bias, BETA);
2606#endif // UNIT_BIAS
2607
2608 // c = c + bias
2609 ADD_BLOCK(4, c, bias);
2610
2611#endif // defined(BROADCAST_BIAS)
2612#endif // defined(BETA)
2613
2614#if defined(ACTIVATION_TYPE)
2615 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
2616#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00002617
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002618 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002619 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2620 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2621 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2622 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002623}
2624
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002625/** This OpenCL kernel is optimized for Bifrost and tt computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002626 *
Gian Marco19835e52018-01-30 13:35:54 +00002627 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002628 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2629 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2630 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2631 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2632 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002633 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002634 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2635 * The activation function is performed after the bias addition
2636 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002637 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2638 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2639 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2640 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2641 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002642 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2643 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2644 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2645 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2646 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2647 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002648 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002649 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2650 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2651 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2652 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2653 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002654 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2655 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2656 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2657 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2658 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2659 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002660 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002661 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002662 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002663 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002664 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002665 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002666 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2667 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002668 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002669 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002670 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002671 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002672__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
2673 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002674#if defined(BETA)
2675 IMAGE_DECLARATION(src2),
2676#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00002677 IMAGE_DECLARATION(dst),
2678 uint src0_stride_z,
2679 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002680#if defined(BETA)
2681 uint src2_stride_z,
2682#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002683 uint dst_stride_z
2684#if defined(REINTERPRET_OUTPUT_AS_3D)
2685 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002686 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002687#endif // REINTERPRET_OUTPUT_AS_3D
2688 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002689{
Gian Marco36a0a462018-01-12 10:21:40 +00002690 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2691 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002692 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00002693
2694 // Offset
2695 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2696 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
2697
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002698 // src_addr_a = address of matrix A
2699 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002700 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2701 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2702
2703#if defined(MATRIX_B_DEPTH)
2704 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2705 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2706#else // defined(MATRIX_B_DEPTH)
2707 src1_addr_in_bytes += z * src1_stride_z;
2708#endif // defined(MATRIX_B_DEPTH)
2709
2710 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2711 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002712
Gian Marco36a0a462018-01-12 10:21:40 +00002713 src_addr_a += offset_row_a;
2714 src_addr_b += offset_row_b;
2715
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002716 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002717 float4 c0 = 0.0f;
2718 float4 c1 = 0.0f;
2719 float4 c2 = 0.0f;
2720 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002721
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002722#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
2723
2724 int i = 0;
2725 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002726 {
2727 // Load values from matrix A (interleaved) and matrix B (transposed)
2728 float4 a0 = vload4(0, src_addr_a);
2729 float4 b0 = vload4(0, src_addr_b);
2730
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002731 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2732 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002733
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002734 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
2735 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
2736 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
2737 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002738
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002739 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
2740 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
2741 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
2742 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002743
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002744 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
2745 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
2746 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
2747 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002748
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002749 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
2750 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
2751 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
2752 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002753
2754 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002755 a0 = vload4(0, src_addr_a);
2756 b0 = vload4(0, src_addr_b);
2757
2758 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2759 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002760
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002761 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
2762 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
2763 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
2764 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002765
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002766 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
2767 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
2768 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
2769 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002770
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002771 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
2772 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
2773 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
2774 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002775
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002776 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
2777 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
2778 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
2779 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002780
2781 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002782 a0 = vload4(0, src_addr_a);
2783 b0 = vload4(0, src_addr_b);
2784
2785 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2786 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2787
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002788 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
2789 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
2790 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
2791 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002792
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002793 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
2794 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
2795 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
2796 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002797
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002798 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
2799 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
2800 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
2801 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002802
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002803 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
2804 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
2805 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
2806 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002807
2808 // Load values from matrix A (interleaved) and matrix B (transposed)
2809 a0 = vload4(0, src_addr_a);
2810 b0 = vload4(0, src_addr_b);
2811
2812 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2813 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002814
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002815 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
2816 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
2817 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
2818 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002819
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002820 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
2821 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
2822 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
2823 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002824
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002825 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
2826 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
2827 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
2828 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002829
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002830 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
2831 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
2832 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
2833 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002834 }
2835
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002836 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002837 {
2838 // Load values from matrix A (interleaved) and matrix B (transposed)
2839 float4 a0 = vload4(0, src_addr_a);
2840 float4 b0 = vload4(0, src_addr_b);
2841
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002842 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2843 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2844
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002845 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
2846 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
2847 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
2848 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002849
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002850 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
2851 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
2852 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
2853 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002854
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002855 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
2856 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
2857 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
2858 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002859
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002860 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
2861 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
2862 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
2863 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002864 }
2865
2866 // Compute destination address
2867 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2868
Gian Marcoae2af742018-02-15 12:35:44 +00002869 // Compute dst address
2870 __global uchar *dst_addr = offset(&dst, 0, 0);
2871
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002872 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002873
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002874#if defined(REINTERPRET_OUTPUT_AS_3D)
2875 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002876 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002877 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002878 // | |
2879 // | plane0 |
2880 // | |
2881 // |__________________|
2882 // |******************|
2883 // | cross_plane_pad |
2884 // |******************|
2885 // | |
2886 // | plane1 |
2887 // | |
2888 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002889
2890 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002891 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2892 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002893
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002894 // Add offset due to the cross plane paddings
2895 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002896
2897 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2898 // multiply dst_stride_z by DEPTH_GEMM3D
2899 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002900#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002901 // Add offset for batched GEMM
2902 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002903#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2904
2905 // Multiply by the weight of matrix-matrix product and store the result
2906#if defined(ALPHA)
2907 SCALE_BLOCK(4, float, c, ALPHA);
2908#endif // defined(ALPHA)
2909
2910 // Add beta*bias
2911#if defined(BETA)
2912 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
2913
2914#if defined(BROADCAST_BIAS)
2915 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
2916
2917 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2918
2919#ifndef UNIT_BETA
2920 SCALE_BLOCK(1, float, bias, BETA);
2921#endif // UNIT_BIAS
2922
2923 // c = c + bias[broadcasted]
2924 ADD_BLOCK_BROADCAST(4, c, bias0);
2925
2926#else // defined(BROADCAST_BIAS)
2927 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
2928 2) * src2_stride_z;
2929
2930 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2931
2932#ifndef UNIT_BETA
2933 SCALE_BLOCK(4, float, bias, BETA);
2934#endif // UNIT_BIAS
2935
2936 // c = c + bias
2937 ADD_BLOCK(4, c, bias);
2938
2939#endif // defined(BROADCAST_BIAS)
2940#endif // defined(BETA)
2941
2942#if defined(ACTIVATION_TYPE)
2943 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
2944#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00002945
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002946 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002947 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2948 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2949 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2950 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002951}
2952
Georgios Pinitas84225582018-05-14 12:00:05 +01002953// Undefine local defines
2954#undef COLS_MTX_B
2955
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002956#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002957/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002958 *
Gian Marco19835e52018-01-30 13:35:54 +00002959 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002960 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2961 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2962 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2963 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002964 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002965 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2966 * The activation function is performed after the bias addition
2967 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002968 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2969 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2970 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2971 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2972 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002973 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2974 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2975 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2976 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2977 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2978 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002979 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002980 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2981 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2982 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2983 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2984 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
2986 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2987 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2988 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2989 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2990 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002991 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002992 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002993 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002994 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002995 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002996 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002997 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2998 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002999 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003000 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003001 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003002 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                )
{
    // x selects the row of src1 (B reshaped), y the row of src0 (A reshaped), z the batch
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Element offset of this work-item's 4-wide (A) and 8-wide (B) chunk inside the selected row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the selected rows of matrix A (src0) and matrix B (src1)
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer planes than matrix A: wrap the batch index so that B is re-used
    // instead of being read out of bounds
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // End address of the matrix B row. It is taken from the row start, i.e. before the
    // per-work-item offset is applied below
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one half8 per output row of the 4x8 tile)
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;

        // Load the next values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image           dst      = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Extra byte offset added to each of the four output rows. It stays 0 unless the output is
    // reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly so that both operands of min() are uint4
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Turn the plane index into the byte offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the matrix-matrix product by alpha
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to all four output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // A full 4x8 bias tile is loaded, addressed with the same global ids as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003176
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition, on the result already converted back to F16
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif // defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                      )
{
    // x selects the row of src1 (B reshaped), y the row of src0 (A reshaped), z the batch
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Element offset of this work-item's 4-wide (A) and 8-wide (B) chunk inside the selected row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the selected rows of matrix A (src0) and matrix B (src1)
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer planes than matrix A: wrap the batch index so that B is re-used
    // instead of being read out of bounds
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // End address of the matrix B row. It is taken from the row start, i.e. before the
    // per-work-item offset is applied below
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one float8 per output row of the 4x8 tile)
    float8 c0 = 0.0f;
    float8 c1 = 0.0f;
    float8 c2 = 0.0f;
    float8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed) and widen them to F32
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;

        // Load the next values from matrix A (interleaved) and matrix B (transposed)
        a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed) and widen them to F32
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image           dst      = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Extra byte offset added to each of the four output rows. It stays 0 unless the output is
    // reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly so that both operands of min() are uint4
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Turn the plane index into the byte offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the matrix-matrix product by alpha
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias (the bias is loaded as F16 and widened to F32 before being scaled and added)
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to all four output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else // defined(BROADCAST_BIAS)
    // A full 4x8 bias tile is loaded, addressed with the same global ids as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Narrow the F32 accumulators back to F16 before the activation and the store
    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
3406
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003407/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003408 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003409 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003410 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3411 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3412 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3413 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003414 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003415 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3416 * The activation function is performed after the bias addition
3417 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003418 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3419 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3420 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3421 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3422 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003423 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3424 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3425 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3426 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3427 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3428 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3429 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3430 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3431 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3432 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3433 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3434 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3436 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3437 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3438 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3439 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3440 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003441 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3442 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3443 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3444 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3445 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3446 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003447 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3448 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3449 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003451 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                         )
{
    // Row of the reshaped matrix B (x) and of the reshaped matrix A (y) read by this work-item, and batch index (z).
    // Work-items whose ids differ only in the remainder of the division read the same row at a different offset.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) inside the reshaped row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one half8 per output row of the 4x8 block)
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Number of accumulation steps: every step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the reshaped matrix B
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    // Main loop: four accumulation steps per iteration
    int i = 0;
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // A single vload8 of matrix A feeds two consecutive steps: .s0-.s3 for the first, .s4-.s7 for the second

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over loop: one accumulation step per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}

// Undefine local defines
#undef COLS_MTX_B
3709
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// The kernels below are only compiled when the number of columns of matrix A and both per-thread tile sizes are provided
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note The bias matrix (src2) is only part of the kernel signature, and only added to the result, if -DBETA is passed at compile time.
 *       The bias is scaled by BETA unless -DUNIT_BETA is defined. If -DBROADCAST_BIAS is defined, a single bias row is loaded and added to every output row.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif //defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First column (in elements) of matrix B / of the output handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (src_addr.s0 is the byte offset into matrix A, src_addr.s1 the byte offset into matrix B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset just past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator vector per output row processed by this work-item
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: two columns of matrix A (and two rows of matrix B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: one column of matrix A (and one row of matrix B) per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004028
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01004029/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004030 *
4031 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4032 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4033 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
4034 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4035 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004036 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4037 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004038 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004039 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4040 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004041 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4042 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004043 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4044 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4045 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4046 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4047 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004048 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004049 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4050 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4051 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4052 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4053 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4054 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4055 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4056 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4057 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4058 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4059 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004060 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
4061 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4062 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4063 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4064 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4065 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004066 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4067 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4068 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4069 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4070 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
4071 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004072 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4073 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004074 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004075 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004076 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4077 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004078 */
4079__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4080 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004081#if defined(BETA)
4082 IMAGE_DECLARATION(src2),
4083#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004084 IMAGE_DECLARATION(dst),
4085 uint src0_stride_z,
4086 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004087#if defined(BETA)
4088 uint src2_stride_z,
4089#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004090 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004091#if defined(REINTERPRET_INPUT_AS_3D)
4092 ,
4093 uint src_cross_plane_pad
4094#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004095#if defined(REINTERPRET_OUTPUT_AS_3D)
4096 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004097 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004098#endif // REINTERPRET_OUTPUT_AS_3D
4099 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004100{
4101 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4102
4103 // Compute starting address for matrix A and matrix B
4104 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4105
4106 // Update address for matrix A
4107 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4108
4109 // Update address for matrix B
4110 src_addr.s1 += idx * sizeof(float);
4111
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004112#if defined(REINTERPRET_INPUT_AS_3D)
4113 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4114 // in order to take into account the presence of possible cross plane paddings
4115 //
4116 // | |
4117 // | plane0 |
4118 // | |
4119 // |__________________|
4120 // |******************|
4121 // | cross_plane_pad |
4122 // |******************|
4123 // | |
4124 // | plane1 |
4125 // | |
4126 // |__________________|
4127
4128 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4129 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4130 zin = min(DEPTH_GEMM3D - 1, zin);
4131
4132 // Add offset due to the cross plane paddings
4133 zin *= (src_cross_plane_pad * src0_stride_y);
4134
4135 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4136 // multiply src0_stride_z by DEPTH_GEMM3D
4137 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4138
4139#else // defined(REINTERPRET_INPUT_AS_3D)
4140
Gian Marcoae2af742018-02-15 12:35:44 +00004141 // Add offset for batched GEMM
4142 src_addr.s0 += get_global_id(2) * src0_stride_z;
4143
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004144#endif // defined(REINTERPRET_INPUT_AS_3D)
4145
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004146#if defined(MATRIX_B_DEPTH)
4147 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4148 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4149#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004150 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004151#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004152
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004153 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004154 float4 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004155
4156#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004157 float4 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004158#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4159
4160#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004161 float4 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004162#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4163
4164#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004165 float4 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004166#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4167
4168 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004169 int i = 0;
4170 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004171 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004172#if defined(REINTERPRET_INPUT_AS_3D)
4173 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004174 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4175#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004176 // Load values from matrix A and matrix B
4177 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004178#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004179 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004180#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4181#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004182 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004183#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4184#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004185 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004186#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004187#endif // defined(REINTERPRET_INPUT_AS_3D)
4188
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004189 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4190 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004191
4192 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004193 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
4194 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
4195 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
4196 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004197
4198#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004199
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004200 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
4201 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
4202 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
4203 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004204
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004205#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4206#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004207
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004208 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
4209 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
4210 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
4211 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004212
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004213#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4214#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004215
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004216 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
4217 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
4218 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
4219 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004220#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004221
4222 // Load values from matrix A and matrix B
4223 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4224 src_addr.s1 += src1_stride_y;
4225
4226 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004227 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
4228 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
4229 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
4230 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004231
4232#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4233
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004234 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
4235 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
4236 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
4237 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004238
4239#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4240#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4241
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004242 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
4243 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
4244 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
4245 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004246
4247#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4248#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4249
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004250 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
4251 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
4252 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
4253 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004254#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4255
4256 // Load values from matrix A and matrix B
4257 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4258 src_addr.s1 += src1_stride_y;
4259
4260 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004261 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
4262 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
4263 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
4264 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004265
4266#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4267
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004268 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
4269 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
4270 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
4271 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004272
4273#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4274#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4275
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004276 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
4277 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
4278 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
4279 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004280
4281#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4282#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4283
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004284 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
4285 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
4286 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
4287 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004288#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4289
4290 // Load values from matrix A and matrix B
4291 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4292 src_addr.s1 += src1_stride_y;
4293
4294 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004295 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
4296 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
4297 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
4298 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004299
4300#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4301
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004302 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
4303 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
4304 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
4305 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004306
4307#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4308#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4309
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004310 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
4311 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
4312 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
4313 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004314
4315#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4316#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4317
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004318 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
4319 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
4320 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
4321 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004322#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4323
4324 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004325 }
4326
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004327 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004328 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004329#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004330 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004331 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4332#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4333 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4334#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4335#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4336 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4337#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4338#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4339 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4340#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4341#else // defined(REINTERPRET_INPUT_AS_3D)
4342 // Load values from matrix A
4343 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004344#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4345 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4346#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4347#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4348 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4349#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4350#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4351 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4352#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004353#endif // defined(REINTERPRET_INPUT_AS_3D)
4354
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004355 // Load values from matrix B
4356 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004357 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004358
4359 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004360 acc0.s0 = fma(a0, b0.s0, acc0.s0);
4361 acc0.s1 = fma(a0, b0.s1, acc0.s1);
4362 acc0.s2 = fma(a0, b0.s2, acc0.s2);
4363 acc0.s3 = fma(a0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004364#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004365 acc1.s0 = fma(a1, b0.s0, acc1.s0);
4366 acc1.s1 = fma(a1, b0.s1, acc1.s1);
4367 acc1.s2 = fma(a1, b0.s2, acc1.s2);
4368 acc1.s3 = fma(a1, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004369#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4370#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004371 acc2.s0 = fma(a2, b0.s0, acc2.s0);
4372 acc2.s1 = fma(a2, b0.s1, acc2.s1);
4373 acc2.s2 = fma(a2, b0.s2, acc2.s2);
4374 acc2.s3 = fma(a2, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004375#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4376#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004377 acc3.s0 = fma(a3, b0.s0, acc3.s0);
4378 acc3.s1 = fma(a3, b0.s1, acc3.s1);
4379 acc3.s2 = fma(a3, b0.s2, acc3.s2);
4380 acc3.s3 = fma(a3, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004381#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004382
4383 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004384 }
4385
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004386 int z = get_global_id(2);
4387
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004388 // Compute destination address
4389 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4390
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004391 // Compute dst address
4392 __global uchar *dst_addr = offset(&dst, 0, 0);
4393
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004394 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004395
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004396#if defined(REINTERPRET_OUTPUT_AS_3D)
4397 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004398 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004399 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004400 // | |
4401 // | plane0 |
4402 // | |
4403 // |__________________|
4404 // |******************|
4405 // | cross_plane_pad |
4406 // |******************|
4407 // | |
4408 // | plane1 |
4409 // | |
4410 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004411
4412 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004413 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4414 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004415
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004416 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004417 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004418
4419 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4420 // multiply dst_stride_z by DEPTH_GEMM3D
4421 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004422#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004423 // Add offset for batched GEMM
4424 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004425#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4426
4427 // Multiply by the weight of matrix-matrix product and store the result
4428#if defined(ALPHA)
4429 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
4430#endif // defined(ALPHA)
4431
4432 // Add beta*bias
4433#if defined(BETA)
4434 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
4435
4436#if defined(BROADCAST_BIAS)
4437 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
4438
4439 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4440
4441#ifndef UNIT_BETA
4442 SCALE_BLOCK(1, float, bias, BETA);
4443#endif // UNIT_BIAS
4444
4445 // acc = acc + bias[broadcasted]
4446 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
4447
4448#else // defined(BROADCAST_BIAS)
4449 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
4450 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
4451
4452 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4453
4454#ifndef UNIT_BETA
4455 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
4456#endif // UNIT_BIAS
4457
4458 // acc = acc + bias
4459 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
4460
4461#endif // defined(BROADCAST_BIAS)
4462#endif // defined(BETA)
4463
4464#if defined(ACTIVATION_TYPE)
4465 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
4466#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004467
4468 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004469 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004470#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004471 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004472#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4473#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004474 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004475#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4476#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004477 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004478#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004479}
4480
4481/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4482 *
4483 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4484 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
4485 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4486 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
4487 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4488 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004489 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4490 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004491 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004492 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4493 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004494 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4495 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004496 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4497 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4498 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4499 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4500 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004501 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004502 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4503 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4504 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4505 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4506 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4507 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4508 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4509 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4510 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4511 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4512 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4514 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4515 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4516 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4517 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4518 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004519 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4520 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4521 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4522 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4523 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
4524 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004525 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4526 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004527 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004528 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004529 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4530 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004531 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                      IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
#if defined(BETA)
                                                      uint src2_stride_z,
#endif // defined(BETA)
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item accumulates NUM_ELEMS_PROCESSED_PER_THREAD_Y (1 to 4) output rows of 2 floats each.
    // The float2 accumulators, the vload2 on matrix B, the bias offset (2 * sizeof(float)) and the final vstore2
    // are all hard-wired to 2 columns, hence the kernel requires NUM_ELEMS_PROCESSED_PER_THREAD_X = 2.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets of the current element in matrix A (.s0) and in matrix B (.s1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Matrix A: jump to the first row handled by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Matrix B: jump to the first column handled by this work-item
    src_addr.s1 += idx * sizeof(float);

    // Extra byte offset applied to each of the (up to) 4 rows loaded from matrix A.
    // It stays at zero unless the input is reinterpreted as a 3D tensor (same scheme as zout for the output).
    uint4 zin = 0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector operand goes first so the call matches the min(gentype, sgentype) overload
    zin = min(zin, (uint)(DEPTH_GEMM3D - 1));

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators (one float2 per output row)
    float2 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float2 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float2 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float2 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Adds the eight products A.sK * bK (K = 0..7, in ascending order) to lane s0 and then to lane s1 of ACC,
    // using one scalar fma per product. Must be expanded where the matrix B rows b0..b7 are in scope.
#define GEMM_F32_BIFROST_1000_FMA8(ACC, A) \
    do                                     \
    {                                      \
        ACC.s0 = fma(A.s0, b0.s0, ACC.s0); \
        ACC.s0 = fma(A.s1, b1.s0, ACC.s0); \
        ACC.s0 = fma(A.s2, b2.s0, ACC.s0); \
        ACC.s0 = fma(A.s3, b3.s0, ACC.s0); \
        ACC.s0 = fma(A.s4, b4.s0, ACC.s0); \
        ACC.s0 = fma(A.s5, b5.s0, ACC.s0); \
        ACC.s0 = fma(A.s6, b6.s0, ACC.s0); \
        ACC.s0 = fma(A.s7, b7.s0, ACC.s0); \
                                           \
        ACC.s1 = fma(A.s0, b0.s1, ACC.s1); \
        ACC.s1 = fma(A.s1, b1.s1, ACC.s1); \
        ACC.s1 = fma(A.s2, b2.s1, ACC.s1); \
        ACC.s1 = fma(A.s3, b3.s1, ACC.s1); \
        ACC.s1 = fma(A.s4, b4.s1, ACC.s1); \
        ACC.s1 = fma(A.s5, b5.s1, ACC.s1); \
        ACC.s1 = fma(A.s6, b6.s1, ACC.s1); \
        ACC.s1 = fma(A.s7, b7.s1, ACC.s1); \
    } while(0)

    // Main loop: consume 8 columns of matrix A (and therefore 8 rows of matrix B) per iteration.
    // A and B src indices get incremented at the same time.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
        // Load 8 values from the first row of matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));

        // Load 8 consecutive rows (2 values each) from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate (row 0)
        GEMM_F32_BIFROST_1000_FMA8(acc0, a0);

        // The remaining rows reuse a0 as the staging register for their 8 values of matrix A
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
        GEMM_F32_BIFROST_1000_FMA8(acc1, a0);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
        GEMM_F32_BIFROST_1000_FMA8(acc2, a0);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
        GEMM_F32_BIFROST_1000_FMA8(acc3, a0);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }

#undef GEMM_F32_BIFROST_1000_FMA8

    // Left-over loop: consume the remaining columns of matrix A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
        // Load one value per row from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load one row (2 values) from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc0.s0 = fma(a0, b0.s0, acc0.s0);
        acc0.s1 = fma(a0, b0.s1, acc0.s1);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1.s0 = fma(a1, b0.s0, acc1.s0);
        acc1.s1 = fma(a1, b0.s1, acc1.s1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2.s0 = fma(a2, b0.s0, acc2.s0);
        acc2.s1 = fma(a2, b0.s1, acc2.s1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3.s0 = fma(a3, b0.s0, acc3.s0);
        acc3.s1 = fma(a3, b0.s1, acc3.s1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Extra byte offset applied to each of the (up to) 4 rows stored; zero unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector operand goes first so the call matches the min(gentype, sgentype) overload
    zout = min(zout, (uint)(DEPTH_GEMM3D - 1));

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to every accumulator row
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));

    LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // One bias row is loaded per accumulator row
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // The activation (if any) is applied after the bias addition
#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
4890
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004891#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4893 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
4895 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4896 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
4897 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4898 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004899 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4900 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004901 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004902 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4903 * The activation function is performed after the bias addition
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004904 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4905 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
4906 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4907 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4908 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4909 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4910 *
4911 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4912 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4913 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4914 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4915 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4916 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4917 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4918 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4919 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4920 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4921 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4922 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4924 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4925 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4926 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4927 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4928 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004929 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4930 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4932 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4934 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
4935 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4936 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004937 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004938 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
4939 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4940 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
4941 */
4942__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
4943 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004944#if defined(BETA)
4945 IMAGE_DECLARATION(src2),
4946#endif // defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004947 IMAGE_DECLARATION(dst),
4948 uint src0_stride_z,
4949 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004950#if defined(BETA)
4951 uint src2_stride_z,
4952#endif //defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004953 uint dst_stride_z
4954#if defined(REINTERPRET_INPUT_AS_3D)
4955 ,
4956 uint src_cross_plane_pad
4957#endif // REINTERPRET_INPUT_AS_3D
4958#if defined(REINTERPRET_OUTPUT_AS_3D)
4959 ,
4960 uint dst_cross_plane_pad
4961#endif // REINTERPRET_OUTPUT_AS_3D
4962 )
4963{
4964 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4965
4966 // Compute starting address for matrix A and Matrix B
4967 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4968
4969 // Update address for the matrix A
4970 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4971
4972 // Update address for the matrix B
4973 src_addr.s1 += idx * sizeof(half);
4974
4975#if defined(REINTERPRET_INPUT_AS_3D)
4976 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4977 // in order to take into account the presence of possible cross plane paddings
4978 //
4979 // | |
4980 // | plane0 |
4981 // | |
4982 // |__________________|
4983 // |******************|
4984 // | cross_plane_pad |
4985 // |******************|
4986 // | |
4987 // | plane1 |
4988 // | |
4989 // |__________________|
4990
4991 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4992 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4993 zin = min(DEPTH_GEMM3D - 1, zin);
4994
4995 // Add offset due to the cross plane paddings
4996 zin *= (src_cross_plane_pad * src0_stride_y);
4997
4998 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4999 // multiply src0_stride_z by DEPTH_GEMM3D
5000 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5001
5002#else // defined(REINTERPRET_INPUT_AS_3D)
5003
5004 // Add offset for batched GEMM
5005 src_addr.s0 += get_global_id(2) * src0_stride_z;
5006
5007#endif // defined(REINTERPRET_INPUT_AS_3D)
5008
5009#if defined(MATRIX_B_DEPTH)
5010 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5011 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5012#else // defined(MATRIX_B_DEPTH)
5013 src_addr.s1 += get_global_id(2) * src1_stride_z;
5014#endif // defined(MATRIX_B_DEPTH)
5015
5016 float8 acc0 = 0.0h;
5017#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5018 float8 acc1 = 0.0h;
5019#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5020#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5021 float8 acc2 = 0.0h;
5022#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5023#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5024 float8 acc3 = 0.0h;
5025#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5026
5027 int i = 0;
5028 for(; i <= ((int)COLS_A - 4); i += 4)
5029 {
5030#if defined(REINTERPRET_INPUT_AS_3D)
5031 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005032 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5033#else // defined(REINTERPRET_INPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005034 // Load values from matrix A
5035 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5036#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5037 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5038#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5039#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5040 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5041#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5042#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5043 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5044#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5045#endif // defined(REINTERPRET_INPUT_AS_3D)
5046
5047 // Load values from matrix B
5048 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5049 src_addr.s1 += src1_stride_y;
5050
5051 // Accumulate
5052 acc0 = fma(b0, (float8)a0.s0, acc0);
5053#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5054 acc1 = fma(b0, (float8)a1.s0, acc1);
5055#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5056#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5057 acc2 = fma(b0, (float8)a2.s0, acc2);
5058#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5059#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5060 acc3 = fma(b0, (float8)a3.s0, acc3);
5061#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5062
5063 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5064 src_addr.s1 += src1_stride_y;
5065 acc0 = fma(b0, (float8)a0.s1, acc0);
5066#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5067 acc1 = fma(b0, (float8)a1.s1, acc1);
5068#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5069#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5070 acc2 = fma(b0, (float8)a2.s1, acc2);
5071#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5072#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5073 acc3 = fma(b0, (float8)a3.s1, acc3);
5074#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5075
5076 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5077 src_addr.s1 += src1_stride_y;
5078 acc0 = fma(b0, (float8)a0.s2, acc0);
5079#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5080 acc1 = fma(b0, (float8)a1.s2, acc1);
5081#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5082#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5083 acc2 = fma(b0, (float8)a2.s2, acc2);
5084#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5085#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5086 acc3 = fma(b0, (float8)a3.s2, acc3);
5087#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5088
5089 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5090 src_addr.s1 += src1_stride_y;
5091 acc0 = fma(b0, (float8)a0.s3, acc0);
5092#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5093 acc1 = fma(b0, (float8)a1.s3, acc1);
5094#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5095#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5096 acc2 = fma(b0, (float8)a2.s3, acc2);
5097#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5098#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5099 acc3 = fma(b0, (float8)a3.s3, acc3);
5100#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5101
5102 src_addr.s0 += 4 * sizeof(half);
5103 }
5104
5105 for(; i < (int)COLS_A; ++i)
5106 {
5107#if defined(REINTERPRET_INPUT_AS_3D)
5108 // Load values from matrix A
5109 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5110#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5111 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5112#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5113#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5114 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5115#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5116#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5117 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5118#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5119#else // defined(REINTERPRET_INPUT_AS_3D)
5120 // Load values from matrix A
5121 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5122#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5123 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5124#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5125#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5126 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5128#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5129 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5130#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5131#endif // defined(REINTERPRET_INPUT_AS_3D)
5132
5133 // Load values from matrix B
5134 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5135
5136 src_addr += (int2)(sizeof(half), src1_stride_y);
5137
5138 // Accumulate
5139 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
5140#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5141 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
5142#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5143#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5144 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
5145#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5146#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5147 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
5148#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5149 }
5150
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005151 int z = get_global_id(2);
5152
5153 // Compute destination address
5154 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5155
5156 // Compute dst address
5157 __global uchar *dst_addr = offset(&dst, 0, 0);
5158
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005159 uint4 zout = 0;
5160
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005161#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005162
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005163 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
5164 // in order to take into account the presence of possible cross plane paddings
5165 //
5166 // | |
5167 // | plane0 |
5168 // | |
5169 // |__________________|
5170 // |******************|
5171 // | cross_plane_pad |
5172 // |******************|
5173 // | |
5174 // | plane1 |
5175 // | |
5176 // |__________________|
5177
5178 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005179 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5180 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005181
5182 // Add offset due to the cross plane paddings
5183 zout *= (dst_cross_plane_pad * dst_stride_y);
5184
5185 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5186 // multiply dst_stride_z by DEPTH_GEMM3D
5187 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005188#else // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005189 // Add offset for batched GEMM
5190 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005191#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005192
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005193 // Multiply by the weight of matrix-matrix product and store the result
5194#if defined(ALPHA)
5195 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5196#endif // defined(ALPHA)
5197
5198#if defined(BETA)
5199 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5200
5201#if defined(BROADCAST_BIAS)
5202 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
5203
5204 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5205
5206 float8 bias_f0 = convert_float8(bias0);
5207
5208#ifndef UNIT_BETA
5209 SCALE_BLOCK(1, float, bias_f, BETA);
5210#endif // UNIT_BIAS
5211
5212 // acc = acc + bias[broadcasted]
5213 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);
5214
5215#else // defined(BROADCAST_BIAS)
5216 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
5217 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5218
5219 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5220
5221 float8 bias_f0 = convert_float8(bias0);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005222#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005223 float8 bias_f1 = convert_float8(bias1);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5225#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005226 float8 bias_f2 = convert_float8(bias2);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005227#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5228#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005229 float8 bias_f3 = convert_float8(bias3);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005230#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005231
5232#ifndef UNIT_BETA
5233 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
5234#endif // UNIT_BIAS
5235
5236 // acc = acc + bias
5237 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);
5238
5239#endif // defined(BROADCAST_BIAS)
5240#endif // defined(BETA)
5241
5242 half8 acc_h0 = convert_half8(acc0);
5243#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5244 half8 acc_h1 = convert_half8(acc1);
5245#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5246#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5247 half8 acc_h2 = convert_half8(acc2);
5248#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5249#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5250 half8 acc_h3 = convert_half8(acc3);
5251#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5252
5253#if defined(ACTIVATION_TYPE)
5254 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
5255#endif // defined(ACTIVATION_TYPE)
5256
5257 // Store the output block
5258 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005259}
5260
5261/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
5262 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005263 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5264 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5265 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5266 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5267 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005268 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5269 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005270 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005271 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5272 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005273 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5274 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005275 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5276 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5277 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5278 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5279 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005280 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5281 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5282 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5283 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5284 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5285 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5286 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5287 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5288 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5289 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5290 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5291 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005292 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
5293 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5294 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5295 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5296 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5297 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005298 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5299 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5300 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5301 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5302 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5303 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005304 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5305 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005306 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005307 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005308 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5309 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005310 */
5311__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5312 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005313#if defined(BETA)
5314 IMAGE_DECLARATION(src2),
5315#endif // defined(BETA)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005316 IMAGE_DECLARATION(dst),
5317 uint src0_stride_z,
5318 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005319#if defined(BETA)
5320 uint src2_stride_z,
5321#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005322 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005323#if defined(REINTERPRET_INPUT_AS_3D)
5324 ,
5325 uint src_cross_plane_pad
5326#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005327#if defined(REINTERPRET_OUTPUT_AS_3D)
5328 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005329 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005330#endif // REINTERPRET_OUTPUT_AS_3D
5331 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005332{
5333 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5334
5335 // Compute starting address for matrix A and Matrix B
5336 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5337
5338 // Update address for the matrix A
5339 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5340
5341 // Update address for the matrix B
5342 src_addr.s1 += idx * sizeof(half);
5343
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005344#if defined(REINTERPRET_INPUT_AS_3D)
5345 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5346 // in order to take into account the presence of possible cross plane paddings
5347 //
5348 // | |
5349 // | plane0 |
5350 // | |
5351 // |__________________|
5352 // |******************|
5353 // | cross_plane_pad |
5354 // |******************|
5355 // | |
5356 // | plane1 |
5357 // | |
5358 // |__________________|
5359
5360 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5361 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5362 zin = min(DEPTH_GEMM3D - 1, zin);
5363
5364 // Add offset due to the cross plane paddings
5365 zin *= (src_cross_plane_pad * src0_stride_y);
5366
5367 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5368 // multiply src0_stride_z by DEPTH_GEMM3D
5369 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5370
5371#else // defined(REINTERPRET_INPUT_AS_3D)
5372
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005373 // Add offset for batched GEMM
5374 src_addr.s0 += get_global_id(2) * src0_stride_z;
5375
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005376#endif // defined(REINTERPRET_INPUT_AS_3D)
5377
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005378#if defined(MATRIX_B_DEPTH)
5379 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5380 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5381#else // defined(MATRIX_B_DEPTH)
5382 src_addr.s1 += get_global_id(2) * src1_stride_z;
5383#endif // defined(MATRIX_B_DEPTH)
5384
5385 half8 acc0 = 0.0h;
5386#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5387 half8 acc1 = 0.0h;
5388#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5389#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5390 half8 acc2 = 0.0h;
5391#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5392#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5393 half8 acc3 = 0.0h;
5394#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5395
5396 int i = 0;
5397 for(; i <= ((int)COLS_A - 4); i += 4)
5398 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005399#if defined(REINTERPRET_INPUT_AS_3D)
5400 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005401 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5402#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005403 // Load values from matrix A
5404 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5405#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5406 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5407#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5408#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5409 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5410#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5411#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5412 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5413#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005414#endif // defined(REINTERPRET_INPUT_AS_3D)
5415
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005416 // Load values from matrix B
5417 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5418 src_addr.s1 += src1_stride_y;
5419
5420 // Accumulate
5421 acc0 = fma(b0, (half8)a0.s0, acc0);
5422#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5423 acc1 = fma(b0, (half8)a1.s0, acc1);
5424#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5425#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5426 acc2 = fma(b0, (half8)a2.s0, acc2);
5427#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5428#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5429 acc3 = fma(b0, (half8)a3.s0, acc3);
5430#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5431
5432 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5433 src_addr.s1 += src1_stride_y;
5434 acc0 = fma(b0, (half8)a0.s1, acc0);
5435#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5436 acc1 = fma(b0, (half8)a1.s1, acc1);
5437#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5438#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5439 acc2 = fma(b0, (half8)a2.s1, acc2);
5440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5442 acc3 = fma(b0, (half8)a3.s1, acc3);
5443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5444
5445 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5446 src_addr.s1 += src1_stride_y;
5447 acc0 = fma(b0, (half8)a0.s2, acc0);
5448#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5449 acc1 = fma(b0, (half8)a1.s2, acc1);
5450#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5451#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5452 acc2 = fma(b0, (half8)a2.s2, acc2);
5453#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5454#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5455 acc3 = fma(b0, (half8)a3.s2, acc3);
5456#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5457
5458 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5459 src_addr.s1 += src1_stride_y;
5460 acc0 = fma(b0, (half8)a0.s3, acc0);
5461#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5462 acc1 = fma(b0, (half8)a1.s3, acc1);
5463#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5464#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5465 acc2 = fma(b0, (half8)a2.s3, acc2);
5466#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5467#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5468 acc3 = fma(b0, (half8)a3.s3, acc3);
5469#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5470
5471 src_addr.s0 += 4 * sizeof(half);
5472 }
5473
5474 for(; i < (int)COLS_A; ++i)
5475 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005476#if defined(REINTERPRET_INPUT_AS_3D)
5477 // Load values from matrix A
5478 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5479#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5480 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5481#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5482#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5483 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5484#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5485#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5486 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5487#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5488#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005489 // Load values from matrix A
5490 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5491#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5492 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5493#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5494#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5495 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5496#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5497#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5498 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5499#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005500#endif // defined(REINTERPRET_INPUT_AS_3D)
5501
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005502 // Load values from matrix B
5503 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5504
5505 src_addr += (int2)(sizeof(half), src1_stride_y);
5506
5507 // Accumulate
5508 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
5509#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5510 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
5511#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5512#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5513 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
5514#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5515#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5516 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
5517#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5518 }
5519
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005520 int z = get_global_id(2);
5521
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005522 // Compute destination address
5523 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5524
5525 // Compute dst address
5526 __global uchar *dst_addr = offset(&dst, 0, 0);
5527
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005528 uint4 zout = 0;
5529
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005530#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005531
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005532 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005533 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005534 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005535 // | |
5536 // | plane0 |
5537 // | |
5538 // |__________________|
5539 // |******************|
5540 // | cross_plane_pad |
5541 // |******************|
5542 // | |
5543 // | plane1 |
5544 // | |
5545 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005546
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005547 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005548 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5549 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005550
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005551 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005552 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005553
5554 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5555 // multiply dst_stride_z by DEPTH_GEMM3D
5556 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005557#else // defined(REINTERPRET_OUTPUT_AS_3D)
5558 // Add offset for batched GEMM
5559 dst_addr += z * dst_stride_z;
5560#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5561
5562 // Multiply by the weight of matrix-matrix product and store the result
5563#if defined(ALPHA)
5564 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
5565#endif // defined(ALPHA)
5566
5567 // Add beta*bias
5568#if defined(BETA)
5569 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5570
5571#if defined(BROADCAST_BIAS)
5572 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
5573
5574 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5575
5576#ifndef UNIT_BETA
5577 SCALE_BLOCK(1, half, bias, BETA);
5578#endif // UNIT_BIAS
5579
5580 // acc = acc + bias[broadcasted]
5581 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5582
5583#else // defined(BROADCAST_BIAS)
5584 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
5585 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5586
5587 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5588
5589#ifndef UNIT_BETA
5590 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
5591#endif // UNIT_BIAS
5592
5593 // acc = acc + bias
5594 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5595
5596#endif // defined(BROADCAST_BIAS)
5597#endif // defined(BETA)
5598
5599#if defined(ACTIVATION_TYPE)
5600 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
5601#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005602
5603 // Store the output block
Usama Arif0681e3b2019-04-25 14:28:07 +01005604 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005605}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005606#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005607
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005608#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005609
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005610#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005611/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
5612 *
Gian Marco19835e52018-01-30 13:35:54 +00005613 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005614 *
5615 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
5616 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
5617 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5618 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
5619 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005620 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
5621 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005622 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005623 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005624 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5625 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5626 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5627 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005628 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5629 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005630 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5631 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005632__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
5633 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005634{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005635 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005636 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
5637 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005638
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005639 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005640 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
5641
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005642 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005643 float4 c = vload4(0, (__global float *)src.ptr);
5644
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005645 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005646 float4 out = alpha_ab + (float4)BETA * c;
5647
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005648 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005649 vstore4(out, 0, (__global float *)dst.ptr);
5650}
5651
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005652#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005653/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
5654 *
 * @note The value of beta needs to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005656 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005657 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
5658 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
5659 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5660 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
5661 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005664 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005665 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005666 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5667 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5668 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5669 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005670 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5671 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005672 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5673 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005674__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
5675 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005676{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005677 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005678 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
5679 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005680
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005681 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005682 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
5683
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005684 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005685 half8 c = vload8(0, (__global half *)src.ptr);
5686
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005687 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005688 half8 out = alpha_ab + (half8)BETA * c;
5689
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005690 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005691 vstore8(out, 0, (__global half *)dst.ptr);
5692}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005693#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005694#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005695
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005696#if defined(WIDTH_VECTOR_A)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005697/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
5698 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005700 *
Gian Marco19835e52018-01-30 13:35:54 +00005701 * @note The input A and matrix B must not be reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005702 *
5703 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
5704 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
5708 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005709 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005710 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src1_step_z src1_stride_z * number of elements along Z processed per workitem(in bytes)
5716 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005717 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005718 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5719 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5720 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5721 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5722 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5723 */
5724__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
5725 TENSOR3D_DECLARATION(src1),
5726 IMAGE_DECLARATION(dst))
5727{
5728 int idx = get_global_id(0) * 4;
5729 int idy = get_global_id(1);
5730
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005731 // Compute the address for the vector A and matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005732 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
5733 src_addr.s1 += idx * sizeof(float);
5734
5735 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
5736
5737 float4 acc = 0.0f;
5738
Georgios Pinitas96880cf2017-10-20 18:52:20 +01005739 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005740 {
5741 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
5742 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5743 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
5744
5745 acc += b0 * (float4)a0.s0;
5746 acc += b1 * (float4)a0.s1;
5747 }
5748
5749 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
5750 {
5751 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
5752 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5753
5754 acc += b0 * (float4)a0;
5755 }
5756
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005757 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005758 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5759
5760 vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
5761}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005762#endif // defined(WIDTH_VECTOR_A)
5763
5764/** This kernel accumulates each row with the biases vector.
5765 *
5766 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
5767 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
5768 *
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01005769 * @param[in, out] accum_ptr Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in] accum_stride_x Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in] accum_step_x accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] accum_stride_y Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in] accum_step_y accum_stride_y * number of elements along Y processed per workitem(in bytes)
5774 * @param[in] accum_offset_first_element_in_bytes The offset of the first element in the accumulate tensor
 * @param[in] biases_ptr Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in] biases_stride_x Stride of the biases vector in X dimension (in bytes)
 * @param[in] biases_step_x biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the biases vector
5779 */
5780#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
5781__kernel void gemm_accumulate_biases(
5782 IMAGE_DECLARATION(accum),
5783 VECTOR_DECLARATION(biases))
5784{
5785 Image accum = CONVERT_TO_IMAGE_STRUCT(accum);
5786 Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
5787
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005788 // Vector size, e.g. number of vector elements.
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005789 VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
5790 accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
5791 VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
5792 biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01005793 accum_value = biases_value + accum_value;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005794 // Store result in the accumulate buffer
5795 VSTORE(VECTOR_SIZE)
5796 (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
5797}
5798#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)