blob: c35d1606894522b6e931bc1f1cef8c75bf30db2b [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Lane-index constants: INC<N> is the N-wide uint vector (0, 1, ..., N-1).
// VEC_DATA_TYPE comes from the included helper headers and is expected to name the uint<N> vector type.
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

// INC(n) selects INC<n>. The extra level of indirection lets a macro argument such as K0 be replaced by
// its numeric value before it is pasted onto the "INC" prefix.
#define CONCAT_INC(n) INC##n
#define INC(n) CONCAT_INC(n)

35
#if(SRC_WIDTH % K0)
/** Zero the lanes of a K0-wide row vector that lie beyond the right edge of the source matrix.
 *
 * Lane l of the vector corresponds to column (x * K0 + l). Lanes whose column is < SRC_WIDTH keep their value,
 * all the others are replaced by 0 (select() picks its second operand where the mask lane is set).
 *
 * @note The mask is converted to VEC_DATA_TYPE(DATA_TYPE, K0) before being handed to select(); this presumably
 *       relies on DATA_TYPE being an integer type of the element size - TODO confirm against the host code.
 *
 * @param[in]      x Index of the block along X (in units of K0 columns)
 * @param[in, out] a Row vector of K0 elements to be masked in place
 */
#define BOUNDARY_CONDITION_X(x, a) \
    ({ \
        a = select(0, a, CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
// SRC_WIDTH is a multiple of K0: no block straddles the right edge, so there is nothing to mask.
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
/** This OpenCL kernel reshapes the lhs input matrix. Each work-item loads one M0xK0 block of the input matrix and
 *  stores it (not transposed) in the output matrix, unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Columns at or beyond SRC_WIDTH are written as 0 when SRC_WIDTH is not a multiple of K0.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the blocks of two consecutive y inside the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: first element of the block, K0 columns per x and M0 rows per y
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes;
    input_ptr += x * (uint)K0 * sizeof(DATA_TYPE);
    input_ptr += y * (uint)M0 * src_stride_y;

    // Output: V0 consecutive blocks along y share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
    output_ptr += (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE));
    output_ptr += ((y / (uint)V0) * (uint)dst_stride_y);
    output_ptr += ((y % (uint)V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else  // defined(REINTERPRET_INPUT_AS_3D)
    input_ptr += z * (uint)src_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block into a0..a(M0-1)
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Clear the columns that lie past SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // No additional per-row offset on the output side
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
/* TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
 *
 * Gathers lane i from the row vectors a0..a(M0-1) visible at the point of use and stores those M0 values
 * contiguously at byte address
 *     output_ptr + 0x<i> * output_step_x * sizeof(DATA_TYPE)
 *
 * output_ptr    : byte pointer to the start of the destination block
 * output_step_x : distance, in elements, between two consecutive transposed columns
 * i             : a single hexadecimal digit (0..9, A..F); it is pasted both into the lane selector (.s<i>)
 *                 and into the literal 0x<i>, so it must be a plain token and not an expression
 *
 * For M0 = 5, 6 and 7 the column is written as a 4-wide store followed by a 1-, 2- or 3-wide store.
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
        VSTORE(M0) \
        (col, 0, col_dst); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
        VSTORE(M0) \
        (col, 0, col_dst); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VSTORE(M0) \
        (col, 0, col_dst); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        DATA_TYPE col_hi = a4.s##i; \
        VSTORE(4) \
        (col_lo, 0, col_dst); \
        *(col_dst + 4) = col_hi; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 2) \
        col_hi = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
        VSTORE(4) \
        (col_lo, 0, col_dst); \
        VSTORE(2) \
        (col_hi, 0, col_dst + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 3) \
        col_hi = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
        VSTORE(4) \
        (col_lo, 0, col_dst); \
        VSTORE(3) \
        (col_hi, 0, col_dst + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0) \
        (col, 0, col_dst); \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
245
/** This OpenCL kernel reshapes the lhs input matrix. Each work-item loads one M0xK0 block of the input matrix and
 *  stores it (transposed, one column of M0 values at a time) in the output matrix, unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Columns at or beyond SRC_WIDTH are written as 0 when SRC_WIDTH is not a multiple of K0.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the blocks of two consecutive y inside the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive transposed columns of the same block
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: first element of the block, K0 columns per x and M0 rows per y
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes;
    input_ptr += x * (uint)K0 * sizeof(DATA_TYPE);
    input_ptr += y * (uint)M0 * src_stride_y;

    // Output: V0 consecutive blocks along y share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
    output_ptr += (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE));
    output_ptr += ((y / (uint)V0) * (uint)dst_stride_y);
    output_ptr += ((y % (uint)V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else  // defined(REINTERPRET_INPUT_AS_3D)
    input_ptr += z * (uint)src_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block into a0..a(M0-1)
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Clear the columns that lie past SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One store per source column; the last argument is the column index as a hexadecimal digit
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. Each work-item loads one K0xN0 block of the input matrix and
 *  stores it (not transposed) in the output matrix, unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 * @note Rows of a block at or beyond SRC_HEIGHT are not read and are written as 0. The first row of every block
 *       is read unconditionally.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: distance (in elements) between the blocks of two consecutive x inside the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Load row `row` of the block into a<suffix>, unless that row is at or beyond SRC_HEIGHT.
    // `suffix` is the hexadecimal digit of the destination variable, `row` the same index in decimal.
#define LOAD_ROW_IF_IN_BOUNDS(suffix, row) \
    do \
    { \
        if(y * (uint)K0 + row < SRC_HEIGHT) \
        { \
            a##suffix = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + row * src_stride_y)); \
        } \
    } while(0)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: first element of the block, N0 columns per x and K0 rows per y
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes;
    input_ptr += x * (uint)N0 * sizeof(DATA_TYPE);
    input_ptr += y * (uint)K0 * src_stride_y;
    input_ptr += z * (uint)src_stride_z;

    // Output: H0 consecutive blocks along x share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes;
    output_ptr += (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE));
    output_ptr += ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
    output_ptr += ((x / (uint)H0) * (uint)dst_stride_y);
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create the K0 row vectors a0, a1, ... (N0 elements each), all initialised to 0
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. The first row is loaded without a bounds check.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    LOAD_ROW_IF_IN_BOUNDS(1, 1);
#endif // K0 > 1
#if K0 > 2
    LOAD_ROW_IF_IN_BOUNDS(2, 2);
#endif // K0 > 2
#if K0 > 3
    LOAD_ROW_IF_IN_BOUNDS(3, 3);
#endif // K0 > 3
#if K0 > 4
    LOAD_ROW_IF_IN_BOUNDS(4, 4);
    LOAD_ROW_IF_IN_BOUNDS(5, 5);
    LOAD_ROW_IF_IN_BOUNDS(6, 6);
    LOAD_ROW_IF_IN_BOUNDS(7, 7);
#endif // K0 > 4
#if K0 > 8
    LOAD_ROW_IF_IN_BOUNDS(8, 8);
    LOAD_ROW_IF_IN_BOUNDS(9, 9);
    LOAD_ROW_IF_IN_BOUNDS(A, 10);
    LOAD_ROW_IF_IN_BOUNDS(B, 11);
    LOAD_ROW_IF_IN_BOUNDS(C, 12);
    LOAD_ROW_IF_IN_BOUNDS(D, 13);
    LOAD_ROW_IF_IN_BOUNDS(E, 14);
    LOAD_ROW_IF_IN_BOUNDS(F, 15);
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    // No additional per-row offset on the output side
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef LOAD_ROW_IF_IN_BOUNDS
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
553
554#if defined(TRANSPOSE)
555/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
556 * the output matrix unrolling the values.
557 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +0100558 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
559 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
560 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
561 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000562 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
563 * @note The option -DTRANSPOSE must passed at compile time.
564 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000565 * N0: 2,3,4,8,16
566 * K0: 2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000567 * H0: greater than 0
568 *
569 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
570 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
571 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
573 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
575 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
576 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
577 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
578 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
579 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
580 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
581 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
582 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
583 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
584 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
585 */
586__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
587 TENSOR3D_DECLARATION(dst))
588{
589 // Block size
590#define BLOCK_SIZE ((K0) * (N0))
591
592 // Output offset X
593#if defined(INTERLEAVE)
594#define OUTPUT_OFFSET_X (K0)
595#else // defined(INTERLEAVE)
596#define OUTPUT_OFFSET_X (BLOCK_SIZE)
597#endif // defined(INTERLEAVE)
598
599 // Output step X
600#if defined(INTERLEAVE)
601#define OUTPUT_STEP_X (K0) * (H0)
602#else // Do not interleave
603#define OUTPUT_STEP_X (K0)
604#endif // defined(INTERLEAVE)
605
606 // Compute source and destination addresses
607 uint x = get_global_id(0);
608 uint y = get_global_id(1);
609 uint z = get_global_id(2);
610
611 // ------------------ Compute input/output addresses ---------------------------
612
613 // Compute the input address
614 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
615
616 // Compute the output address
617 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
618 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
619
620 // ---------------------------Load input values --------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000621 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000622
623 // Load values from the RHS matrix
624 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
625 if(y * (uint)K0 + 1 < SRC_HEIGHT)
626 {
627 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
628 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000629#if K0 > 2
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000630 if(y * (uint)K0 + 2 < SRC_HEIGHT)
631 {
632 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
633 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000634#endif // K0 > 2
635#if K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000636 if(y * (uint)K0 + 3 < SRC_HEIGHT)
637 {
638 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
639 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000640#endif // K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000641#if K0 > 4
642 if(y * (uint)K0 + 4 < SRC_HEIGHT)
643 {
644 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
645 }
646 if(y * (uint)K0 + 5 < SRC_HEIGHT)
647 {
648 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
649 }
650 if(y * (uint)K0 + 6 < SRC_HEIGHT)
651 {
652 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
653 }
654 if(y * (uint)K0 + 7 < SRC_HEIGHT)
655 {
656 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
657 }
658#endif // K0 > 4
659#if K0 > 8
Gian Marco Iodice89124342018-12-19 14:17:22 +0000660 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000661 {
662 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
663 }
664 if(y * (uint)K0 + 9 < SRC_HEIGHT)
665 {
666 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
667 }
668 if(y * (uint)K0 + 10 < SRC_HEIGHT)
669 {
670 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
671 }
672 if(y * (uint)K0 + 11 < SRC_HEIGHT)
673 {
674 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
675 }
676 if(y * (uint)K0 + 12 < SRC_HEIGHT)
677 {
678 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
679 }
680 if(y * (uint)K0 + 13 < SRC_HEIGHT)
681 {
682 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
683 }
684 if(y * (uint)K0 + 14 < SRC_HEIGHT)
685 {
686 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
687 }
688 if(y * (uint)K0 + 15 < SRC_HEIGHT)
689 {
690 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
691 }
692#endif // K0 > 8
693
694 // ---------------------------Transpose the block ------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000695 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000696
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000697#if K0 == 2
698 // This part computes the following transpositions:
699 // 2x2 -> 2x2
700 // 2x4 -> 4x2
701 // 2x8 -> 8x2
702 // 2x16 -> 16x2
703 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
704 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
705#if N0 > 2
706 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
707#endif // N0 > 2
708#if N0 > 3
709 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
710#endif // N0 > 3
711#if N0 > 4
712 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
713 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
714 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
715 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
716#endif // N0 > 4
717#if N0 > 8
718 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
719 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
720 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
721 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
722 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
723 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
724 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
725 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
726#endif // N0 > 8
727
728#elif K0 == 3 // K0 == 2
729 // This part computes the following transpositions:
730 // 3x2 -> 2x3
731 // 3x4 -> 4x3
732 // 3x8 -> 8x3
733 // 3x16 -> 16x3
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100734 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
735 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000736#if N0 > 2
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100737 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000738#endif // N0 > 2
739#if N0 > 3
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100740 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000741#endif // N0 > 3
742#if N0 > 4
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100743 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
744 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
745 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
746 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000747#endif // N0 > 4
748#if N0 > 8
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100749 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
750 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
751 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
752 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
753 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
754 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
755 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
756 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000757#endif // N0 > 8
758
759#elif K0 == 4 // K0 == 4
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000760 // This part computes the following transpositions:
761 // 4x2 -> 2x4
762 // 4x4 -> 4x4
763 // 4x8 -> 8x4
764 // 4x16 -> 16x4
765 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
766 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
767#if N0 > 2
768 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000769#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000770#if N0 > 3
771 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
772#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000773#if N0 > 4
774 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
775 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
776 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
777 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
778#endif // N0 > 4
779#if N0 > 8
780 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
781 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
782 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
783 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
784 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
785 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
786 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
787 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
788#endif // N0 > 8
789
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000790#elif K0 == 8 // K0 == 8
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000791 // This part computes the following transpositions:
792 // 8x2 -> 2x8
793 // 8x4 -> 4x8
794 // 8x8 -> 8x8
795 // 8x16 -> 16x8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000796 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
797 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000798#if N0 > 2
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000799 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000800#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000801#if N0 > 3
802 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
803#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000804#if N0 > 4
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000805 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
806 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
807 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
808 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000809#endif // N0 > 4
810#if N0 > 8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000811 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
812 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
813 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
814 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
815 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
816 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
817 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
818 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000819#endif // N0 > 8
820
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000821#elif K0 == 16 // K0 == 16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000822
823 // This part computes the following transpositions:
824 // 16x2 -> 2x16
825 // 16x4 -> 4x16
826 // 16x8 -> 8x16
827 // 16x16 -> 16x16
828 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
829 a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
830 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
831 a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
832#if N0 > 2
833 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
834 a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000835#endif // N0 > 2
836#if N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000837 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
838 a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000839#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000840#if N0 > 4
841 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
842 a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
843 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
844 a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
845 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
846 a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
847 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
848 a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
849#endif // N0 > 4
850#if N0 > 8
851 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
852 a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
853 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
854 a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
855 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
856 a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
857 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
858 a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
859 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
860 a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
861 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
862 a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
863 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
864 a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
865 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
866 a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
867#endif // N0 > 8
868
869#else // N0 == 16
870#error "Not supported N0 value"
871#endif // N0 > 2
872
873 // ---------------------------Store the output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100874 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
875 STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000876
877#undef BLOCK_SIZE
878#undef OUTPUT_OFFSET_X
879#undef OUTPUT_STEP_X
880}
881#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
/** Pastes the tokens @p a and @p b together. When CONCAT is invoked from inside another macro with one of that
 *  macro's parameters as argument, the parameter is expanded before pasting (e.g. K0 -> 4 gives ARM_DOT4).
 */
#define CONCAT(a, b) a##b

/** ARM_DOT<n>(a, b, c): accumulates into @p c the dot product of the first n components of @p a and @p b.
 *
 * One fma is issued per component, in ascending component order. For n = 1 the operands are scalars.
 * @p c must be an lvalue. @p a and @p b are wrapped in parentheses before any component access, so
 * arbitrary expressions can be passed.
 */
#define ARM_DOT1(a, b, c)         \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
#define ARM_DOT2(a, b, c)               \
    ({                                  \
        (c) = fma((a).s0, (b).s0, (c)); \
        (c) = fma((a).s1, (b).s1, (c)); \
    })
#define ARM_DOT3(a, b, c)               \
    ({                                  \
        ARM_DOT2(a, b, c);              \
        (c) = fma((a).s2, (b).s2, (c)); \
    })
#define ARM_DOT4(a, b, c)               \
    ({                                  \
        ARM_DOT3(a, b, c);              \
        (c) = fma((a).s3, (b).s3, (c)); \
    })
#define ARM_DOT8(a, b, c)                 \
    ({                                    \
        ARM_DOT4(((a).lo), ((b).lo), c);  \
        ARM_DOT4(((a).hi), ((b).hi), c);  \
    })
#define ARM_DOT16(a, b, c)                \
    ({                                    \
        ARM_DOT8(((a).lo), ((b).lo), c);  \
        ARM_DOT8(((a).hi), ((b).hi), c);  \
    })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001025 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001034 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1036 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1037 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1038 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1042 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1044 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1058 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
1115 // Compute RHS matrix address
1116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
1161 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
1198 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
/** Fused multiply-add accumulation: c = a * b + c.
 *
 * @param[in]      a First multiplicand
 * @param[in]      b Second multiplicand
 * @param[in, out] c Accumulator; must be an lvalue, it is overwritten with the result
 */
#define VFMA(a, b, c)                \
    ({                               \
        (c) = fma((a), (b), (c));    \
    })
1300
1301#if M0 == 1
1302#define LD_RHS_VFMA_M0xN0(i, a, c) \
1303 ({ \
1304 VEC_DATA_TYPE(DATA_TYPE, N0) \
1305 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1306 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1307 })
1308#elif M0 == 2 // M0 == 2
1309#define LD_RHS_VFMA_M0xN0(i, a, c) \
1310 ({ \
1311 VEC_DATA_TYPE(DATA_TYPE, N0) \
1312 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1313 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1314 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1315 })
1316#elif M0 == 3 // M0 == 3
1317#define LD_RHS_VFMA_M0xN0(i, a, c) \
1318 ({ \
1319 VEC_DATA_TYPE(DATA_TYPE, N0) \
1320 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1321 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1322 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1323 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1324 })
1325#elif M0 == 4 // M0 == 4
1326#define LD_RHS_VFMA_M0xN0(i, a, c) \
1327 ({ \
1328 VEC_DATA_TYPE(DATA_TYPE, N0) \
1329 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1330 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1331 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1332 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1333 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1334 })
1335#elif M0 == 5 // M0 == 5
1336#define LD_RHS_VFMA_M0xN0(i, a, c) \
1337 ({ \
1338 VEC_DATA_TYPE(DATA_TYPE, N0) \
1339 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1340 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1341 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1342 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1343 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1344 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1345 })
1346#elif M0 == 6 // M0 == 6
1347#define LD_RHS_VFMA_M0xN0(i, a, c) \
1348 ({ \
1349 VEC_DATA_TYPE(DATA_TYPE, N0) \
1350 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1351 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1352 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1353 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1354 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1355 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1356 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1357 })
1358#elif M0 == 7 // M0 == 7
1359#define LD_RHS_VFMA_M0xN0(i, a, c) \
1360 ({ \
1361 VEC_DATA_TYPE(DATA_TYPE, N0) \
1362 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1363 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1364 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1365 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1366 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1367 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1368 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1369 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1370 })
1371#elif M0 == 8 // M0 == 8
1372#define LD_RHS_VFMA_M0xN0(i, a, c) \
1373 ({ \
1374 VEC_DATA_TYPE(DATA_TYPE, N0) \
1375 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1376 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1377 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1378 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1379 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1380 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1381 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1382 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1383 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1384 })
1385#else // M0 not supported
1386#error "M0 not supported"
1387#endif // M0 not supported
1388
1389/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1390 * The LHS matrix is NOT reshaped
1391 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1392 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001393 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001394 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90).
1395 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1396 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1397 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001398 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1399 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1400 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1401 * - N0 = 2, 3, 4, 8, 16
1402 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001403 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001404 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001405 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001406 * The activation function is performed after the bias addition
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001407 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1408 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1409 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1410 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1411 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1412 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1413 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001414 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1415 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1416 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1417 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1418 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1419 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1420 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1421 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1422 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1423 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1424 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1425 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001426 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1427 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001428 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001429 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001430 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1431 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1432 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1433 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1434 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1435 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1436 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1437 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1438 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1439 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001440 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001441 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1442 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1443 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001444 */
1445__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1446 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001447#if defined(BETA)
1448 IMAGE_DECLARATION(bias),
1449#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001450 IMAGE_DECLARATION(dst),
1451 uint lhs_stride_z,
1452 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001453#if defined(BETA)
1454 uint bias_stride_z,
1455#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001456 uint dst_stride_z
1457#if defined(REINTERPRET_INPUT_AS_3D)
1458 ,
1459 uint lhs_cross_plane_pad
1460#endif // REINTERPRET_INPUT_AS_3D
1461#if defined(REINTERPRET_OUTPUT_AS_3D)
1462 ,
1463 uint dst_cross_plane_pad
1464#endif // REINTERPRET_OUTPUT_AS_3D
1465 )
1466{
1467 // Block size
1468#define RHS_BLOCK_SIZE ((K0) * (N0))
1469
1470 // RHS offset and step X
1471#if defined(RHS_INTERLEAVE)
1472#define RHS_OFFSET_X (N0)
1473#define RHS_STEP_X ((N0) * (H0))
1474#define RHS_STEP_LOOP (1)
1475#else // defined(RHS_INTERLEAVE)
1476#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1477#define RHS_STEP_X (N0)
1478#define RHS_STEP_LOOP (H0)
1479#endif // defined(RHS_INTERLEAVE)
1480
1481 uint x = get_global_id(0);
1482 uint y = get_global_id(1);
1483 uint z = get_global_id(2);
1484
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001485#if defined(DUMMY_WORK_ITEMS)
1486 if((x * N0 >= N) || (y * M0 >= M))
1487 {
1488 return;
1489 }
1490#endif // defined(DUMMY_WORK_ITEMS)
1491
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001492 // Compute LHS matrix address
1493 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1494
1495 // Compute RHS matrix address
1496 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1497
1498#if defined(MATRIX_B_DEPTH)
1499 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1500 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1501#else // defined(MATRIX_B_DEPTH)
1502 rhs_offset += z * rhs_stride_z;
1503#endif // defined(MATRIX_B_DEPTH)
1504
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001505 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1506 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001507
1508#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001509
1510 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001511 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001512
1513 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1514 // multiply lhs_stride_z by DEPTH_GEMM3D
1515 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1516
1517#else // defined(REINTERPRET_INPUT_AS_3D)
1518
1519 // Add offset for batched GEMM
1520 lhs_offset += z * lhs_stride_z;
1521
1522#endif // defined(REINTERPRET_INPUT_AS_3D)
1523
1524 // Initialize the accumulators
1525 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1526
1527 int i = 0;
1528 for(; i <= (K - K0); i += K0)
1529 {
1530 // Supported cases (M0, K0):
1531 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1532 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1533 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1534 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1535 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1536 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1537 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1538 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1539 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001540 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001541
1542 LD_RHS_VFMA_M0xN0(0, a, c);
1543 LD_RHS_VFMA_M0xN0(1, a, c);
1544#if K0 > 2
1545 LD_RHS_VFMA_M0xN0(2, a, c);
1546#endif // K0 > 2
1547#if K0 > 3
1548 LD_RHS_VFMA_M0xN0(3, a, c);
1549#endif // K0 > 3
1550#if K0 > 4
1551 LD_RHS_VFMA_M0xN0(4, a, c);
1552 LD_RHS_VFMA_M0xN0(5, a, c);
1553 LD_RHS_VFMA_M0xN0(6, a, c);
1554 LD_RHS_VFMA_M0xN0(7, a, c);
1555#endif // K0 > 4
1556#if K0 > 8
1557 LD_RHS_VFMA_M0xN0(8, a, c);
1558 LD_RHS_VFMA_M0xN0(9, a, c);
1559 LD_RHS_VFMA_M0xN0(A, a, c);
1560 LD_RHS_VFMA_M0xN0(B, a, c);
1561 LD_RHS_VFMA_M0xN0(C, a, c);
1562 LD_RHS_VFMA_M0xN0(D, a, c);
1563 LD_RHS_VFMA_M0xN0(E, a, c);
1564 LD_RHS_VFMA_M0xN0(F, a, c);
1565#endif // K0 > 8
1566
1567 lhs_offset += K0 * sizeof(DATA_TYPE);
1568 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1569 }
1570
1571 // Left-over accumulations
1572 for(; i < K; ++i)
1573 {
1574 // Load values from LHS matrix
1575 VEC_DATA_TYPE(DATA_TYPE, 2)
1576 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1577#if M0 > 1
1578 VEC_DATA_TYPE(DATA_TYPE, 2)
1579 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1580#endif // M0 > 1
1581#if M0 > 2
1582 VEC_DATA_TYPE(DATA_TYPE, 2)
1583 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1584#endif // M0 > 2
1585#if M0 > 3
1586 VEC_DATA_TYPE(DATA_TYPE, 2)
1587 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1588#endif // M0 > 3
1589#if M0 > 4
1590 VEC_DATA_TYPE(DATA_TYPE, 2)
1591 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1592#endif // M0 > 4
1593#if M0 > 5
1594 VEC_DATA_TYPE(DATA_TYPE, 2)
1595 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1596#endif // M0 > 5
1597#if M0 > 6
1598 VEC_DATA_TYPE(DATA_TYPE, 2)
1599 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1600#endif // M0 > 6
1601#if M0 > 7
1602 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001603 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001604#endif // M0 > 7
1605
1606 LD_RHS_VFMA_M0xN0(0, a, c);
1607
1608 lhs_offset += sizeof(DATA_TYPE);
1609 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1610 }
1611
1612 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1613
1614 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1615
1616#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001617 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001618 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001619
1620 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1621 // multiply dst_stride_z by DEPTH_GEMM3D
1622 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1623
1624#else // defined(REINTERPRET_OUTPUT_AS_3D)
1625
1626 // Add offset for batched GEMM
1627 dst_addr += z * dst_stride_z;
1628
1629#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1630
1631 // Multiply by the weight of matrix-matrix product and store the result
1632#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001633 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001634#endif // defined(ALPHA)
1635
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001636 // Add beta*bias
1637#if defined(BETA)
1638#if defined(BROADCAST_BIAS)
1639 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1640
1641 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1642
1643#ifndef UNIT_BETA
1644 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1645#endif // UNIT_BIAS
1646
1647 // c = c + bias[broadcasted]
1648 ADD_BLOCK_BROADCAST(M0, c, bias0);
1649
1650#else // defined(BROADCAST_BIAS)
1651 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1652 2) * bias_stride_z;
1653
1654 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1655
1656#ifndef UNIT_BETA
1657 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1658#endif // UNIT_BIAS
1659
1660 // c = c + bias
1661 ADD_BLOCK(M0, c, bias);
1662
1663#endif // defined(BROADCAST_BIAS)
1664#endif // defined(BETA)
1665
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001666#if defined(ACTIVATION_TYPE)
1667 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1668#endif // defined(ACTIVATION_TYPE)
1669
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001670 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001671 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001672
1673#undef RHS_BLOCK_SIZE
1674#undef RHS_OFFSET_X
1675#undef RHS_STEP_X
1676}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001677#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001678
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001679#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001680
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001681#if K0 == 2
1682#define ARM_DOT_K0(a, b, c) \
1683 ({ \
1684 c = fma(a.s0, b.s0, c); \
1685 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001686 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001687#elif K0 == 3 // K0 == 3
1688#define ARM_DOT_K0(a, b, c) \
1689 ({ \
1690 c = fma(a.s0, b.s0, c); \
1691 c = fma(a.s1, b.s1, c); \
1692 c = fma(a.s2, b.s2, c); \
1693 })
1694#elif K0 == 4 // K0 == 4
1695#define ARM_DOT_K0(a, b, c) \
1696 ({ \
1697 c = fma(a.s0, b.s0, c); \
1698 c = fma(a.s1, b.s1, c); \
1699 c = fma(a.s2, b.s2, c); \
1700 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001701 })
1702#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001703#define ARM_DOT_K0(a, b, c) \
1704 ({ \
1705 c = fma(a.s0, b.s0, c); \
1706 c = fma(a.s1, b.s1, c); \
1707 c = fma(a.s2, b.s2, c); \
1708 c = fma(a.s3, b.s3, c); \
1709 c = fma(a.s4, b.s4, c); \
1710 c = fma(a.s5, b.s5, c); \
1711 c = fma(a.s6, b.s6, c); \
1712 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001713 })
1714#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001715#define ARM_DOT_K0(a, b, c) \
1716 ({ \
1717 c = fma(a.s0, b.s0, c); \
1718 c = fma(a.s1, b.s1, c); \
1719 c = fma(a.s2, b.s2, c); \
1720 c = fma(a.s3, b.s3, c); \
1721 c = fma(a.s4, b.s4, c); \
1722 c = fma(a.s5, b.s5, c); \
1723 c = fma(a.s6, b.s6, c); \
1724 c = fma(a.s7, b.s7, c); \
1725 c = fma(a.s8, b.s8, c); \
1726 c = fma(a.s9, b.s9, c); \
1727 c = fma(a.sA, b.sA, c); \
1728 c = fma(a.sB, b.sB, c); \
1729 c = fma(a.sC, b.sC, c); \
1730 c = fma(a.sD, b.sD, c); \
1731 c = fma(a.sE, b.sE, c); \
1732 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001733 })
1734#else // K0 not supported
1735#error "K0 value not supported"
1736#endif // K0 conditions
1737
1738#if N0 == 2
1739#define ARM_DOT_K0XN0(a, b, c) \
1740 ({ \
1741 ARM_DOT_K0((a), (b##0), (c.s0)); \
1742 ARM_DOT_K0((a), (b##1), (c.s1)); \
1743 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001744#elif N0 == 3 // N0 == 3
1745#define ARM_DOT_K0XN0(a, b, c) \
1746 ({ \
1747 ARM_DOT_K0((a), (b##0), (c.s0)); \
1748 ARM_DOT_K0((a), (b##1), (c.s1)); \
1749 ARM_DOT_K0((a), (b##2), (c.s2)); \
1750 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001751#elif N0 == 4 // N0 == 4
1752#define ARM_DOT_K0XN0(a, b, c) \
1753 ({ \
1754 ARM_DOT_K0((a), (b##0), (c.s0)); \
1755 ARM_DOT_K0((a), (b##1), (c.s1)); \
1756 ARM_DOT_K0((a), (b##2), (c.s2)); \
1757 ARM_DOT_K0((a), (b##3), (c.s3)); \
1758 })
1759#elif N0 == 8 // N0 == 8
1760#define ARM_DOT_K0XN0(a, b, c) \
1761 ({ \
1762 ARM_DOT_K0((a), (b##0), (c.s0)); \
1763 ARM_DOT_K0((a), (b##1), (c.s1)); \
1764 ARM_DOT_K0((a), (b##2), (c.s2)); \
1765 ARM_DOT_K0((a), (b##3), (c.s3)); \
1766 ARM_DOT_K0((a), (b##4), (c.s4)); \
1767 ARM_DOT_K0((a), (b##5), (c.s5)); \
1768 ARM_DOT_K0((a), (b##6), (c.s6)); \
1769 ARM_DOT_K0((a), (b##7), (c.s7)); \
1770 })
1771#elif N0 == 16 // N0 == 16
1772#define ARM_DOT_K0XN0(a, b, c) \
1773 ({ \
1774 ARM_DOT_K0((a), (b##0), (c.s0)); \
1775 ARM_DOT_K0((a), (b##1), (c.s1)); \
1776 ARM_DOT_K0((a), (b##2), (c.s2)); \
1777 ARM_DOT_K0((a), (b##3), (c.s3)); \
1778 ARM_DOT_K0((a), (b##4), (c.s4)); \
1779 ARM_DOT_K0((a), (b##5), (c.s5)); \
1780 ARM_DOT_K0((a), (b##6), (c.s6)); \
1781 ARM_DOT_K0((a), (b##7), (c.s7)); \
1782 ARM_DOT_K0((a), (b##8), (c.s8)); \
1783 ARM_DOT_K0((a), (b##9), (c.s9)); \
1784 ARM_DOT_K0((a), (b##A), (c.sA)); \
1785 ARM_DOT_K0((a), (b##B), (c.sB)); \
1786 ARM_DOT_K0((a), (b##C), (c.sC)); \
1787 ARM_DOT_K0((a), (b##D), (c.sD)); \
1788 ARM_DOT_K0((a), (b##E), (c.sE)); \
1789 ARM_DOT_K0((a), (b##F), (c.sF)); \
1790 })
1791#else // N0 not supported
1792#error "N0 value not supported"
1793#endif // N0 conditions
1794
1795/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1796 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1797 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1798 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001799 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001800 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
1801 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1802 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1803 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1806 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01001807 * - M0 = 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001808 * - N0 = 2, 3, 4, 8, 16
1809 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001810 * - V0 >= 1
1811 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001812 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001813 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001814 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001815 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001816 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1817 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1818 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1819 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1820 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001821 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1822 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1823 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1824 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1825 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1826 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1827 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1828 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1829 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1830 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1831 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1832 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1833 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1834 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1835 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1836 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1837 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1838 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1839 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1840 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1841 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1842 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1843 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1844 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1845 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1846 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1847 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1848 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1849 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1850 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001851 */
1852__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1853 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001854#if defined(BETA)
1855 IMAGE_DECLARATION(bias),
1856#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001857 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001858 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001859 uint lhs_stride_z,
1860 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001861#if defined(BETA)
1862 uint bias_stride_z,
1863#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001864 uint dst_stride_z
1865#if defined(REINTERPRET_OUTPUT_AS_3D)
1866 ,
1867 uint dst_cross_plane_pad
1868#endif // REINTERPRET_OUTPUT_AS_3D
1869 )
1870{
1871 // Block size
1872#define LHS_BLOCK_SIZE ((K0) * (M0))
1873
1874#if defined(LHS_INTERLEAVE)
1875#define LHS_OFFSET_X (K0)
1876#define LHS_STEP_X ((K0) * (V0))
1877#define LHS_STEP_LOOP (1)
1878#else // defined(INTERLEAVE)
1879#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1880#define LHS_STEP_X (K0)
1881#define LHS_STEP_LOOP (V0)
1882#endif // defined(INTERLEAVE)
1883
1884 // Block size
1885#define RHS_BLOCK_SIZE ((K0) * (N0))
1886
1887 // RHS offset and step X
1888#if defined(RHS_INTERLEAVE)
1889#define RHS_OFFSET_X (K0)
1890#define RHS_STEP_X ((K0) * (H0))
1891#define RHS_STEP_LOOP (1)
1892#else // defined(RHS_INTERLEAVE)
1893#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1894#define RHS_STEP_X (K0)
1895#define RHS_STEP_LOOP (H0)
1896#endif // defined(RHS_INTERLEAVE)
1897
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001898#if defined(DUMMY_WORK_ITEMS)
1899 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1900 {
1901 return;
1902 }
1903#endif // defined(DUMMY_WORK_ITEMS)
1904
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001905 // Compute LHS matrix address
1906 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1907 (get_global_id(2) * lhs_stride_z);
1908
1909 // Compute RHS matrix address
1910 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1911
1912#if defined(MATRIX_B_DEPTH)
1913 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1914 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1915#else // defined(MATRIX_B_DEPTH)
1916 rhs_addr += get_global_id(2) * rhs_stride_z;
1917#endif // defined(MATRIX_B_DEPTH)
1918
1919 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001920 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001921
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001922 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1923 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001924
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001925 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001926 {
1927 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001928 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1929 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1930 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1931 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1932 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1933 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1934 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1935 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001936 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001937 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001938
1939 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001940 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001941
1942 // Accumulate
1943 ARM_DOT_K0XN0(a0, b, c0);
1944#if M0 > 1
1945 ARM_DOT_K0XN0(a1, b, c1);
1946#endif // M0 > 1
1947#if M0 > 2
1948 ARM_DOT_K0XN0(a2, b, c2);
1949#endif // M0 > 2
1950#if M0 > 3
1951 ARM_DOT_K0XN0(a3, b, c3);
1952#endif // M0 > 3
1953#if M0 > 4
1954 ARM_DOT_K0XN0(a4, b, c4);
1955#endif // M0 > 4
1956#if M0 > 5
1957 ARM_DOT_K0XN0(a5, b, c5);
1958#endif // M0 > 5
1959#if M0 > 6
1960 ARM_DOT_K0XN0(a6, b, c6);
1961#endif // M0 > 6
1962#if M0 > 7
1963 ARM_DOT_K0XN0(a7, b, c7);
1964#endif // M0 > 7
1965
1966 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1967 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1968 }
1969
1970 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1971
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001972 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001973
1974#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001975
1976 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001977 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001978 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1979 // multiply dst_stride_z by DEPTH_GEMM3D
1980 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1981
1982#else // defined(REINTERPRET_OUTPUT_AS_3D)
1983
1984 // Add offset for batched GEMM
1985 dst_addr += get_global_id(2) * dst_stride_z;
1986
1987#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1988
1989 // Multiply by the weight of matrix-matrix product and store the result
1990#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001991 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001992#endif // defined(ALPHA)
1993
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001994 // Add beta*bias
1995#if defined(BETA)
1996#if defined(BROADCAST_BIAS)
1997 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1998
1999 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2000
2001#ifndef UNIT_BETA
2002 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2003#endif // UNIT_BIAS
2004
2005 // c = c + bias[broadcasted]
2006 ADD_BLOCK_BROADCAST(M0, c, bias0);
2007
2008#else // defined(BROADCAST_BIAS)
2009 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2010 2) * bias_stride_z;
2011
2012 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2013
2014#ifndef UNIT_BETA
2015 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2016#endif // UNIT_BIAS
2017
2018 // c = c + bias
2019 ADD_BLOCK(M0, c, bias);
2020
2021#endif // defined(BROADCAST_BIAS)
2022#endif // defined(BETA)
2023
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002024#if defined(ACTIVATION_TYPE)
2025 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2026#endif // defined(ACTIVATION_TYPE)
2027
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002028 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01002029 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002030
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002031#undef LHS_BLOCK_SIZE
2032#undef LHS_OFFSET_X
2033#undef LHS_STEP_X
2034#undef RHS_BLOCK_SIZE
2035#undef RHS_OFFSET_X
2036#undef RHS_STEP_X
2037}
giuros01b3204e72019-04-01 13:50:22 +01002038
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002039#if defined(LHS_TRANSPOSE)
2040
// Shorthand for a vector type of SIZE elements of TYPE
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)

// Multiply-accumulate: c = a * b + c.
// A plain multiply-add is used when GPU_ARCH == GPU_ARCH_MIDGARD, the fma() built-in otherwise.
// Note: the expansion already ends with a semicolon.
#if GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

// Column-vector (transposed) by row-vector (not transposed) products for a fixed number of rows.
// Row r of the output (C##r) accumulates a.s<r> (broadcast to N0 lanes) times b.
// The single-row variant takes a scalar "a", hence no component selection.
#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)     \
    ({                                             \
        ARM_VFMA((VTYPE(TYPE, N0))(a), b, (C##0)); \
    })
#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)        \
    ({                                                \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
    })
#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)        \
    ({                                                \
        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);       \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
    })
#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)        \
    ({                                                \
        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);       \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
    })
#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)        \
    ({                                                \
        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);       \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
        ARM_VFMA((VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
    })

// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
// a is the column-vector (transposed)
// b is the row-vector (not transposed)
// C is the output matrix
// Lower case is a vector (a, b)
// Upper case is a matrix (C)
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)

// Matrix (transposed) by matrix (not transposed) products for a fixed K0.
// A and B are name prefixes: step j of the accumulation uses the variables A##j and B##j.
#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##0), (B##0), C); \
    })
#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##1), (B##1), C); \
    })
#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##2), (B##2), C); \
    })
#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##3), (B##3), C); \
    })
#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##4), (B##4), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##5), (B##5), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##6), (B##6), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##7), (B##7), C); \
    })
// The parameters of the K0 = 16 variant are deliberately NOT called A, B and C: steps 10..15 paste the
// hexadecimal suffixes A..F, and a suffix that is also a parameter name would be substituted before pasting.
// Steps 8..F go through the vector-by-vector macro because the operands are already complete
// (parenthesised) names.
#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, LHS, RHS, ACC)            \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, ACC);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##8), (RHS##8), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##9), (RHS##9), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##A), (RHS##A), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##B), (RHS##B), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##C), (RHS##C), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##D), (RHS##D), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##E), (RHS##E), ACC); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##F), (RHS##F), ACC); \
    })

// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
// The dimensions for this matrix multiplications are defined through M0, N0 and K0
// The dimensions supported are:
// M0: 1, 2, 3, 4, 8
// N0: 1, 2, 3, 4, 8, 16
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-vector macro K0 times
// A, B and C are matrices
#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
    (M0, N0, TYPE, A, B, C)
2136
2137/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2138 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2139 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2140 *
2141 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2142 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2143 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
2144 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2145 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2146 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
2147 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
2148 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2149 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2150 * - M0 = 2, 3, 4, 8
2151 * - N0 = 2, 3, 4, 8, 16
2152 * - K0 = 2, 3, 4, 8, 16
2153 * - V0 >= 1
2154 * - H0 >= 1
2155 *
2156 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2157 * The activation function is performed after the bias addition
2158 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2159 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2160 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2161 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2162 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2163 *
2164 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2165 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2166 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2167 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2168 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2169 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2170 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2171 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2172 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2173 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2174 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2175 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2176 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2177 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2178 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2179 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2180 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2181 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2182 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2183 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2184 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2185 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2186 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2187 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2188 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2189 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2190 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2191 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2192 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2193 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2194 */
2195__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
2196 IMAGE_DECLARATION(rhs),
2197#if defined(BETA)
2198 IMAGE_DECLARATION(bias),
2199#endif // defined(BETA)
2200 IMAGE_DECLARATION(dst),
2201 uint k,
2202 uint lhs_stride_z,
2203 uint rhs_stride_z,
2204#if defined(BETA)
2205 uint bias_stride_z,
2206#endif //defined(BETA)
2207 uint dst_stride_z
2208#if defined(REINTERPRET_OUTPUT_AS_3D)
2209 ,
2210 uint dst_cross_plane_pad
2211#endif // REINTERPRET_OUTPUT_AS_3D
2212 )
2213{
2214 // Block size
2215#define LHS_BLOCK_SIZE ((K0) * (M0))
2216
2217#if defined(LHS_INTERLEAVE)
2218#define LHS_OFFSET_X (M0)
2219#define LHS_STEP_X ((M0) * (V0))
2220#define LHS_STEP_LOOP (1)
2221#else // defined(INTERLEAVE)
2222#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2223#define LHS_STEP_X (M0)
2224#define LHS_STEP_LOOP (V0)
2225#endif // defined(INTERLEAVE)
2226
2227 // Block size
2228#define RHS_BLOCK_SIZE ((K0) * (N0))
2229
2230 // RHS offset and step X
2231#if defined(RHS_INTERLEAVE)
2232#define RHS_OFFSET_X (N0)
2233#define RHS_STEP_X ((N0) * (H0))
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002234#else // defined(RHS_INTERLEAVE)
2235#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2236#define RHS_STEP_X (N0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002237#endif // defined(RHS_INTERLEAVE)
2238
2239 const uint x = get_global_id(0);
2240 const uint y = get_global_id(1);
2241 const uint z = get_global_id(2);
2242
2243#if defined(DUMMY_WORK_ITEMS)
2244 if((x * N0 >= N) || (y * M0 >= M))
2245 {
2246 return;
2247 }
2248#endif // defined(DUMMY_WORK_ITEMS)
2249
2250 // Compute LHS matrix address
2251 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2252
2253 // Compute RHS matrix address
2254 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2255
2256#if defined(MATRIX_B_DEPTH)
2257 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2258 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2259#else // defined(MATRIX_B_DEPTH)
2260 rhs_addr += z * rhs_stride_z;
2261#endif // defined(MATRIX_B_DEPTH)
2262
2263 // Initialize the accumulators
2264 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
2265
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002266 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2267
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002268 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2269 __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2270
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002271 for(int i = 0; i < k; i += K0)
2272 {
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002273 VEC_DATA_TYPE(DATA_TYPE, M0)
2274 a0 = VLOAD(M0)(0, lhs);
2275 VEC_DATA_TYPE(DATA_TYPE, N0)
2276 b0 = VLOAD(N0)(0, rhs);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002277
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002278 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002279
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002280 lhs += LHS_STEP_X;
2281 rhs += RHS_STEP_X;
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002282
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002283#if K0 > 1
2284 a0 = VLOAD(M0)(0, lhs);
2285 b0 = VLOAD(N0)(0, rhs);
2286
2287 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2288
2289 lhs += LHS_STEP_X;
2290 rhs += RHS_STEP_X;
2291#endif // K0 > 1
2292
2293#if K0 > 2
2294 a0 = VLOAD(M0)(0, lhs);
2295 b0 = VLOAD(N0)(0, rhs);
2296
2297 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2298
2299 lhs += LHS_STEP_X;
2300 rhs += RHS_STEP_X;
2301#endif // K0 > 2
2302
2303#if K0 > 3
2304 a0 = VLOAD(M0)(0, lhs);
2305 b0 = VLOAD(N0)(0, rhs);
2306
2307 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2308
2309 lhs += LHS_STEP_X;
2310 rhs += RHS_STEP_X;
2311#endif // K0 > 3
2312
2313#if K0 > 4
2314 a0 = VLOAD(M0)(0, lhs);
2315 b0 = VLOAD(N0)(0, rhs);
2316
2317 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2318
2319 lhs += LHS_STEP_X;
2320 rhs += RHS_STEP_X;
2321
2322 a0 = VLOAD(M0)(0, lhs);
2323 b0 = VLOAD(N0)(0, rhs);
2324
2325 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2326
2327 lhs += LHS_STEP_X;
2328 rhs += RHS_STEP_X;
2329
2330 a0 = VLOAD(M0)(0, lhs);
2331 b0 = VLOAD(N0)(0, rhs);
2332
2333 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2334
2335 lhs += LHS_STEP_X;
2336 rhs += RHS_STEP_X;
2337
2338 a0 = VLOAD(M0)(0, lhs);
2339 b0 = VLOAD(N0)(0, rhs);
2340
2341 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2342
2343 lhs += LHS_STEP_X;
2344 rhs += RHS_STEP_X;
2345#endif // K0 > 4
2346
2347#if K0 > 8
2348 a0 = VLOAD(M0)(0, lhs);
2349 b0 = VLOAD(N0)(0, rhs);
2350
2351 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2352
2353 lhs += LHS_STEP_X;
2354 rhs += RHS_STEP_X;
2355
2356 a0 = VLOAD(M0)(0, lhs);
2357 b0 = VLOAD(N0)(0, rhs);
2358
2359 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2360
2361 lhs += LHS_STEP_X;
2362 rhs += RHS_STEP_X;
2363
2364 a0 = VLOAD(M0)(0, lhs);
2365 b0 = VLOAD(N0)(0, rhs);
2366
2367 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2368
2369 lhs += LHS_STEP_X;
2370 rhs += RHS_STEP_X;
2371
2372 a0 = VLOAD(M0)(0, lhs);
2373 b0 = VLOAD(N0)(0, rhs);
2374
2375 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2376
2377 lhs += LHS_STEP_X;
2378 rhs += RHS_STEP_X;
2379
2380 a0 = VLOAD(M0)(0, lhs);
2381 b0 = VLOAD(N0)(0, rhs);
2382
2383 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2384
2385 lhs += LHS_STEP_X;
2386 rhs += RHS_STEP_X;
2387
2388 a0 = VLOAD(M0)(0, lhs);
2389 b0 = VLOAD(N0)(0, rhs);
2390
2391 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2392
2393 lhs += LHS_STEP_X;
2394 rhs += RHS_STEP_X;
2395
2396 a0 = VLOAD(M0)(0, lhs);
2397 b0 = VLOAD(N0)(0, rhs);
2398
2399 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2400
2401 lhs += LHS_STEP_X;
2402 rhs += RHS_STEP_X;
2403
2404 a0 = VLOAD(M0)(0, lhs);
2405 b0 = VLOAD(N0)(0, rhs);
2406
2407 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2408
2409 lhs += LHS_STEP_X;
2410 rhs += RHS_STEP_X;
2411#endif // K0 > 8
2412
2413#ifndef LHS_INTERLEAVE
2414 lhs += (M0 * K0 * (V0 - 1));
2415#endif // LHS_INTERLEAVE
2416
2417#ifndef RHS_INTERLEAVE
2418 rhs += (N0 * K0 * (H0 - 1));
2419#endif // RHS_INTERLEAVE
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002420 }
2421
2422 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2423
2424 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2425
2426#if defined(REINTERPRET_OUTPUT_AS_3D)
2427
2428 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2429 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2430 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2431 // multiply dst_stride_z by DEPTH_GEMM3D
2432 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2433
2434#else // defined(REINTERPRET_OUTPUT_AS_3D)
2435
2436 // Add offset for batched GEMM
2437 dst_addr += z * dst_stride_z;
2438
2439#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2440
2441 // Multiply by the weight of matrix-matrix product and store the result
2442#if defined(ALPHA)
2443 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2444#endif // defined(ALPHA)
2445
2446 // Add beta*bias
2447#if defined(BETA)
2448#if defined(BROADCAST_BIAS)
2449 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
2450
2451 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2452
2453#ifndef UNIT_BETA
2454 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2455#endif // UNIT_BIAS
2456
2457 // c = c + bias[broadcasted]
2458 ADD_BLOCK_BROADCAST(M0, c, bias0);
2459
2460#else // defined(BROADCAST_BIAS)
2461 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
2462
2463 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2464
2465#ifndef UNIT_BETA
2466 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2467#endif // UNIT_BIAS
2468
2469 // c = c + bias
2470 ADD_BLOCK(M0, c, bias);
2471
2472#endif // defined(BROADCAST_BIAS)
2473#endif // defined(BETA)
2474
2475#if defined(ACTIVATION_TYPE)
2476 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2477#endif // defined(ACTIVATION_TYPE)
2478
2479 // Store output block
2480 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2481
2482#undef LHS_BLOCK_SIZE
2483#undef LHS_OFFSET_X
2484#undef LHS_STEP_X
2485#undef RHS_BLOCK_SIZE
2486#undef RHS_OFFSET_X
2487#undef RHS_STEP_X
2488}
2489
2490#endif // defined(LHS_TRANSPOSE)
2491
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002492#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2493
giuros01b3204e72019-04-01 13:50:22 +01002494#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2495
2496#define VFMA(a, b, c) \
2497 ({ \
2498 c = fma(a, b, c); \
2499 })
2500
2501#if M0 == 1
2502#define RHS_VFMA_M0xN0(i, a, b, c) \
2503 ({ \
2504 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2505 })
2506#elif M0 == 2 // M0 == 2
2507#define RHS_VFMA_M0xN0(i, a, b, c) \
2508 ({ \
2509 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2510 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2511 })
2512#elif M0 == 3 // M0 == 3
2513#define RHS_VFMA_M0xN0(i, a, b, c) \
2514 ({ \
2515 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2516 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2517 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2518 })
2519#elif M0 == 4 // M0 == 4
2520#define RHS_VFMA_M0xN0(i, a, b, c) \
2521 ({ \
2522 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2523 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2524 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2525 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2526 })
2527#elif M0 == 5 // M0 == 5
2528#define RHS_VFMA_M0xN0(i, a, b, c) \
2529 ({ \
2530 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2531 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2532 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2533 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2534 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2535 })
2536#elif M0 == 6 // M0 == 6
2537#define RHS_VFMA_M0xN0(i, a, b, c) \
2538 ({ \
2539 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2540 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2541 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2542 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2543 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2544 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2545 })
2546#elif M0 == 7 // M0 == 7
2547#define RHS_VFMA_M0xN0(i, a, b, c) \
2548 ({ \
2549 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2550 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2551 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2552 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2553 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2554 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2555 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2556 })
2557#elif M0 == 8 // M0 == 8
2558#define RHS_VFMA_M0xN0(i, a, b, c) \
2559 ({ \
2560 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2561 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2562 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2563 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2564 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2565 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2566 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2567 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
2568 })
2569#else // M0 not supported
2570#error "M0 not supported"
2571#endif // M0 not supported
2572
2573/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2574 * The LHS matrix is NOT reshaped
2575 * The RHS matrix is NOT reshaped
2576 *
2577 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
2579 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
2580 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
2581 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
2582 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
giuros01b3204e72019-04-01 13:50:22 +01002583 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2584 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2585 * - N0 = 2, 3, 4, 8, 16
2586 * - K0 = 2, 3, 4, 8, 16
2587 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002588 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002589 * The activation function is performed after the bias addition
giuros01b3204e72019-04-01 13:50:22 +01002590 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2591 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2592 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2593 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2594 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2595 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2596 *
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002597 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
2598 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
2599 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
2600 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
2601 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2602 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
2603 * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
2604 * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
2605 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
2606 * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
2607 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
2608 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002609 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2610 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2611 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2612 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2613 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2614 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2615 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2616 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2617 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2618 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2619 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2620 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2621 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
2622 * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
2623 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2624 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2625 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2626 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
giuros01b3204e72019-04-01 13:50:22 +01002627 */
2628__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2629 IMAGE_DECLARATION(rhs),
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002630#if defined(BETA)
2631 IMAGE_DECLARATION(bias),
2632#endif // defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002633 IMAGE_DECLARATION(dst),
2634 uint lhs_stride_z,
2635 uint rhs_stride_z,
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002636#if defined(BETA)
2637 uint bias_stride_z,
2638#endif //defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002639 uint dst_stride_z
2640#if defined(REINTERPRET_INPUT_AS_3D)
2641 ,
2642 uint lhs_cross_plane_pad
2643#endif // REINTERPRET_INPUT_AS_3D
2644#if defined(REINTERPRET_OUTPUT_AS_3D)
2645 ,
2646 uint dst_cross_plane_pad
2647#endif // REINTERPRET_OUTPUT_AS_3D
2648 )
2649{
2650 // Block size
2651#define RHS_BLOCK_SIZE ((K0) * (N0))
2652
2653 // RHS offset and step X
2654#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2655
2656 uint x = get_global_id(0);
2657 uint y = get_global_id(1);
2658 uint z = get_global_id(2);
2659
2660#if defined(DUMMY_WORK_ITEMS)
2661 if((x * N0 >= N) || (y * M0 >= M))
2662 {
2663 return;
2664 }
2665#endif // defined(DUMMY_WORK_ITEMS)
2666
2667 // Compute LHS matrix address
2668 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2669
2670 // Compute RHS matrix address
2671 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2672
2673#if defined(MATRIX_B_DEPTH)
2674 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2675 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2676#else // defined(MATRIX_B_DEPTH)
2677 rhs_offset += z * rhs_stride_z;
2678#endif // defined(MATRIX_B_DEPTH)
2679
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002680 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
2681 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
giuros01b3204e72019-04-01 13:50:22 +01002682
2683#if defined(REINTERPRET_INPUT_AS_3D)
2684 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2685 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2686
2687 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2688 // multiply lhs_stride_z by DEPTH_GEMM3D
2689 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2690
2691#else // defined(REINTERPRET_INPUT_AS_3D)
2692
2693 // Add offset for batched GEMM
2694 lhs_offset += z * lhs_stride_z;
2695
2696#endif // defined(REINTERPRET_INPUT_AS_3D)
2697
2698 // Initialize the accumulators
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002699 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
giuros01b3204e72019-04-01 13:50:22 +01002700
2701 int i = 0;
2702 for(; i <= (K - K0); i += K0)
2703 {
2704 // Supported cases (M0, K0):
2705 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2706 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2707 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2708 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2709 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2710 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2711 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2712 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2713 // Load values from LHS matrix
2714 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2715
2716 // Load values from RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002717 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
giuros01b3204e72019-04-01 13:50:22 +01002718
2719 RHS_VFMA_M0xN0(0, a, b0, c);
2720 RHS_VFMA_M0xN0(1, a, b1, c);
2721#if K0 > 2
2722 RHS_VFMA_M0xN0(2, a, b2, c);
2723#endif // K0 > 2
2724#if K0 > 3
2725 RHS_VFMA_M0xN0(3, a, b3, c);
2726#endif // K0 > 3
2727#if K0 > 4
2728 RHS_VFMA_M0xN0(4, a, b4, c);
2729 RHS_VFMA_M0xN0(5, a, b5, c);
2730 RHS_VFMA_M0xN0(6, a, b6, c);
2731 RHS_VFMA_M0xN0(7, a, b7, c);
2732#endif // K0 > 4
2733#if K0 > 8
2734 RHS_VFMA_M0xN0(8, a, b8, c);
2735 RHS_VFMA_M0xN0(9, a, b9, c);
Gian Marco Iodice7b9d7ca2019-09-19 16:37:39 +01002736 RHS_VFMA_M0xN0(A, a, bA, c);
2737 RHS_VFMA_M0xN0(B, a, bB, c);
2738 RHS_VFMA_M0xN0(C, a, bC, c);
2739 RHS_VFMA_M0xN0(D, a, bD, c);
2740 RHS_VFMA_M0xN0(E, a, bE, c);
2741 RHS_VFMA_M0xN0(F, a, bF, c);
giuros01b3204e72019-04-01 13:50:22 +01002742#endif // K0 > 8
2743
2744 lhs_offset += K0 * sizeof(DATA_TYPE);
2745 rhs_offset += K0 * rhs_stride_y;
2746 }
2747
2748 // Left-over accumulations
2749 for(; i < K; ++i)
2750 {
2751 // Load values from LHS matrix
2752 VEC_DATA_TYPE(DATA_TYPE, 2)
2753 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2754#if M0 > 1
2755 VEC_DATA_TYPE(DATA_TYPE, 2)
2756 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2757#endif // M0 > 1
2758#if M0 > 2
2759 VEC_DATA_TYPE(DATA_TYPE, 2)
2760 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2761#endif // M0 > 2
2762#if M0 > 3
2763 VEC_DATA_TYPE(DATA_TYPE, 2)
2764 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2765#endif // M0 > 3
2766#if M0 > 4
2767 VEC_DATA_TYPE(DATA_TYPE, 2)
2768 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2769#endif // M0 > 4
2770#if M0 > 5
2771 VEC_DATA_TYPE(DATA_TYPE, 2)
2772 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2773#endif // M0 > 5
2774#if M0 > 6
2775 VEC_DATA_TYPE(DATA_TYPE, 2)
2776 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2777#endif // M0 > 6
2778#if M0 > 7
2779 VEC_DATA_TYPE(DATA_TYPE, 2)
2780 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2781#endif // M0 > 7
2782
2783 VEC_DATA_TYPE(DATA_TYPE, N0)
2784 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2785 RHS_VFMA_M0xN0(0, a, b, c);
2786
2787 lhs_offset += sizeof(DATA_TYPE);
2788 rhs_offset += rhs_stride_y;
2789 }
2790
2791 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2792
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002793 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
giuros01b3204e72019-04-01 13:50:22 +01002794
2795#if defined(REINTERPRET_OUTPUT_AS_3D)
2796 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2797 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2798
2799 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2800 // multiply dst_stride_z by DEPTH_GEMM3D
2801 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2802
2803#else // defined(REINTERPRET_OUTPUT_AS_3D)
2804
2805 // Add offset for batched GEMM
2806 dst_addr += z * dst_stride_z;
2807
2808#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2809
2810 // Multiply by the weight of matrix-matrix product and store the result
giuros01b3204e72019-04-01 13:50:22 +01002811#if defined(ALPHA)
2812 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2813#endif // defined(ALPHA)
2814
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002815 // Add beta*bias
2816#if defined(BETA)
2817#if defined(BROADCAST_BIAS)
2818 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2819
2820 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2821
2822#ifndef UNIT_BETA
2823 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2824#endif // UNIT_BIAS
2825
2826 // c = c + bias[broadcasted]
2827 ADD_BLOCK_BROADCAST(M0, c, bias0);
2828
2829#else // defined(BROADCAST_BIAS)
2830 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2831 2) * bias_stride_z;
2832
2833 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2834
2835#ifndef UNIT_BETA
2836 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2837#endif // UNIT_BIAS
2838
2839 // c = c + bias
2840 ADD_BLOCK(M0, c, bias);
2841
2842#endif // defined(BROADCAST_BIAS)
2843#endif // defined(BETA)
2844
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002845#if defined(ACTIVATION_TYPE)
2846 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2847#endif // defined(ACTIVATION_TYPE)
2848
giuros01b3204e72019-04-01 13:50:22 +01002849 // Store output block
2850 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2851
2852#undef RHS_BLOCK_SIZE
2853#undef RHS_OFFSET_X
2854#undef RHS_STEP_X
2855}
2856#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2857
Gian Marco36a0a462018-01-12 10:21:40 +00002858#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002859/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002860 *
Gian Marco19835e52018-01-30 13:35:54 +00002861 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002862 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2863 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2864 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2865 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002866 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002867 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2868 * The activation function is performed after the bias addition
2869 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002870 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2871 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2872 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2873 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2874 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002875 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2876 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2877 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2878 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2879 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2880 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002881 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002882 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2883 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2884 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2885 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2886 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002887 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2888 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2889 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2890 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2891 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2892 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002893 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002894 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002895 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002896 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002897 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002898 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002899 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2900 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002901 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002902 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002903 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002904 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002905__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2906 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002907#if defined(BETA)
2908 IMAGE_DECLARATION(src2),
2909#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002910 IMAGE_DECLARATION(dst),
2911 uint src0_stride_z,
2912 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002913#if defined(BETA)
2914 uint src2_stride_z,
2915#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002916 uint dst_stride_z
2917#if defined(REINTERPRET_OUTPUT_AS_3D)
2918 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002919 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002920#endif // REINTERPRET_OUTPUT_AS_3D
2921 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002922{
Gian Marco36a0a462018-01-12 10:21:40 +00002923 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2924 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002925 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002926
Gian Marco36a0a462018-01-12 10:21:40 +00002927 // Offset
2928 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2929 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002930
Gian Marco36a0a462018-01-12 10:21:40 +00002931 // src_addr_a = address of matrix A
2932 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002933 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2934 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2935
2936#if defined(MATRIX_B_DEPTH)
2937 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2938 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2939#else // defined(MATRIX_B_DEPTH)
2940 src1_addr_in_bytes += z * src1_stride_z;
2941#endif // defined(MATRIX_B_DEPTH)
2942
2943 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2944 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002945
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002946 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002947 __global float *src_end_addr_b = src_addr_b + COLS_B;
2948
2949 src_addr_a += offset_row_a;
2950 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002951
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002952 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002953 float4 c0 = 0.0f;
2954 float4 c1 = 0.0f;
2955 float4 c2 = 0.0f;
2956 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002957
Gian Marco36a0a462018-01-12 10:21:40 +00002958 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002959 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002960 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002961 float4 a0 = vload4(0, src_addr_a);
2962 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002963
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002964 c0 += (float4)a0.s0 * b0;
2965 c1 += (float4)a0.s1 * b0;
2966 c2 += (float4)a0.s2 * b0;
2967 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002968
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002969 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002970 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2971 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002972
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002973 c0 += (float4)a0.s0 * b0;
2974 c1 += (float4)a0.s1 * b0;
2975 c2 += (float4)a0.s2 * b0;
2976 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002977 }
2978
Gian Marco36a0a462018-01-12 10:21:40 +00002979 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002980 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002981 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002982 float4 a0 = vload4(0, src_addr_a);
2983 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002984
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002985 c0 += (float4)a0.s0 * b0;
2986 c1 += (float4)a0.s1 * b0;
2987 c2 += (float4)a0.s2 * b0;
2988 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002989 }
2990
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002991 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002992 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2993
Gian Marcoae2af742018-02-15 12:35:44 +00002994 // Compute dst address
2995 __global uchar *dst_addr = offset(&dst, 0, 0);
2996
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002997 uint4 zout = 0;
2998
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002999#if defined(REINTERPRET_OUTPUT_AS_3D)
3000 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003001 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003002 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003003 // | |
3004 // | plane0 |
3005 // | |
3006 // |__________________|
3007 // |******************|
3008 // | cross_plane_pad |
3009 // |******************|
3010 // | |
3011 // | plane1 |
3012 // | |
3013 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003014
3015 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003016 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3017 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003018
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003019 // Add offset due to the cross plane paddings
3020 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003021
3022 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3023 // multiply dst_stride_z by DEPTH_GEMM3D
3024 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003025#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003026 // Add offset for batched GEMM
3027 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003028#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3029
3030 // Multiply by the weight of matrix-matrix product and store the result
3031#if defined(ALPHA)
3032 SCALE_BLOCK(4, float, c, ALPHA);
3033#endif // defined(ALPHA)
3034
3035 // Add beta*bias
3036#if defined(BETA)
3037 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3038
3039#if defined(BROADCAST_BIAS)
3040 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
3041
3042 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3043
3044#ifndef UNIT_BETA
3045 SCALE_BLOCK(1, float, bias, BETA);
3046#endif // UNIT_BIAS
3047
3048 // c = c + bias[broadcasted]
3049 ADD_BLOCK_BROADCAST(4, c, bias0);
3050
3051#else // defined(BROADCAST_BIAS)
3052 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3053 2) * src2_stride_z;
3054
3055 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3056
3057#ifndef UNIT_BETA
3058 SCALE_BLOCK(4, float, bias, BETA);
3059#endif // UNIT_BIAS
3060
3061 // c = c + bias
3062 ADD_BLOCK(4, c, bias);
3063
3064#endif // defined(BROADCAST_BIAS)
3065#endif // defined(BETA)
3066
3067#if defined(ACTIVATION_TYPE)
3068 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
3069#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003070
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003071 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003072 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3073 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3074 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3075 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003076}
3077
/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003079 *
Gian Marco19835e52018-01-30 13:35:54 +00003080 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003081 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3082 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3083 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3084 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003125__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
3126 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003127#if defined(BETA)
3128 IMAGE_DECLARATION(src2),
3129#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00003130 IMAGE_DECLARATION(dst),
3131 uint src0_stride_z,
3132 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003133#if defined(BETA)
3134 uint src2_stride_z,
3135#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003136 uint dst_stride_z
3137#if defined(REINTERPRET_OUTPUT_AS_3D)
3138 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003139 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003140#endif // REINTERPRET_OUTPUT_AS_3D
3141 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003142{
Gian Marco36a0a462018-01-12 10:21:40 +00003143 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3144 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003145 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00003146
3147 // Offset
3148 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3149 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
3150
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003151 // src_addr_a = address of matrix A
3152 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003153 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3154 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3155
3156#if defined(MATRIX_B_DEPTH)
3157 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3158 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3159#else // defined(MATRIX_B_DEPTH)
3160 src1_addr_in_bytes += z * src1_stride_z;
3161#endif // defined(MATRIX_B_DEPTH)
3162
3163 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
3164 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003165
Gian Marco36a0a462018-01-12 10:21:40 +00003166 src_addr_a += offset_row_a;
3167 src_addr_b += offset_row_b;
3168
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003169 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003170 float4 c0 = 0.0f;
3171 float4 c1 = 0.0f;
3172 float4 c2 = 0.0f;
3173 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003174
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003175#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
3176
3177 int i = 0;
3178 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003179 {
3180 // Load values from matrix A (interleaved) and matrix B (transposed)
3181 float4 a0 = vload4(0, src_addr_a);
3182 float4 b0 = vload4(0, src_addr_b);
3183
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003184 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3185 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003186
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003187 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3188 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3189 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3190 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003191
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003192 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3193 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3194 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3195 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003196
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003197 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3198 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3199 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3200 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003201
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003202 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3203 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3204 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3205 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003206
3207 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003208 a0 = vload4(0, src_addr_a);
3209 b0 = vload4(0, src_addr_b);
3210
3211 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3212 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003213
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003214 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3215 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3216 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3217 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003218
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003219 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3220 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3221 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3222 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003223
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003224 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3225 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3226 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3227 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003228
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003229 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3230 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3231 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3232 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003233
3234 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003235 a0 = vload4(0, src_addr_a);
3236 b0 = vload4(0, src_addr_b);
3237
3238 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3239 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
3240
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003241 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3242 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3243 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3244 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003245
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003246 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3247 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3248 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3249 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003250
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003251 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3252 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3253 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3254 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003255
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003256 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3257 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3258 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3259 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003260
3261 // Load values from matrix A (interleaved) and matrix B (transposed)
3262 a0 = vload4(0, src_addr_a);
3263 b0 = vload4(0, src_addr_b);
3264
3265 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3266 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003267
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003268 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3269 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3270 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3271 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003272
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003273 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3274 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3275 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3276 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003277
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003278 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3279 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3280 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3281 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003282
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003283 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3284 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3285 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3286 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003287 }
3288
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003289 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003290 {
3291 // Load values from matrix A (interleaved) and matrix B (transposed)
3292 float4 a0 = vload4(0, src_addr_a);
3293 float4 b0 = vload4(0, src_addr_b);
3294
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003295 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3296 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
3297
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003298 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3299 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3300 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3301 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003302
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003303 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3304 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3305 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3306 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003307
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003308 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3309 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3310 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3311 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003312
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003313 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3314 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3315 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3316 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003317 }
3318
3319 // Compute destination address
3320 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3321
Gian Marcoae2af742018-02-15 12:35:44 +00003322 // Compute dst address
3323 __global uchar *dst_addr = offset(&dst, 0, 0);
3324
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003325 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003326
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003327#if defined(REINTERPRET_OUTPUT_AS_3D)
3328 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003329 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003330 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003331 // | |
3332 // | plane0 |
3333 // | |
3334 // |__________________|
3335 // |******************|
3336 // | cross_plane_pad |
3337 // |******************|
3338 // | |
3339 // | plane1 |
3340 // | |
3341 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003342
3343 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003344 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3345 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003346
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003347 // Add offset due to the cross plane paddings
3348 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003349
3350 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3351 // multiply dst_stride_z by DEPTH_GEMM3D
3352 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003353#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003354 // Add offset for batched GEMM
3355 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003356#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3357
3358 // Multiply by the weight of matrix-matrix product and store the result
3359#if defined(ALPHA)
3360 SCALE_BLOCK(4, float, c, ALPHA);
3361#endif // defined(ALPHA)
3362
3363 // Add beta*bias
3364#if defined(BETA)
3365 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3366
3367#if defined(BROADCAST_BIAS)
3368 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
3369
3370 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3371
3372#ifndef UNIT_BETA
3373 SCALE_BLOCK(1, float, bias, BETA);
3374#endif // UNIT_BIAS
3375
3376 // c = c + bias[broadcasted]
3377 ADD_BLOCK_BROADCAST(4, c, bias0);
3378
3379#else // defined(BROADCAST_BIAS)
3380 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3381 2) * src2_stride_z;
3382
3383 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3384
3385#ifndef UNIT_BETA
3386 SCALE_BLOCK(4, float, bias, BETA);
3387#endif // UNIT_BIAS
3388
3389 // c = c + bias
3390 ADD_BLOCK(4, c, bias);
3391
3392#endif // defined(BROADCAST_BIAS)
3393#endif // defined(BETA)
3394
3395#if defined(ACTIVATION_TYPE)
3396 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
3397#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003398
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003399 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003400 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3401 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3402 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3403 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003404}
3405
Georgios Pinitas84225582018-05-14 12:00:05 +01003406// Undefine local defines
3407#undef COLS_MTX_B
3408
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003409#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003410/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003411 *
Gian Marco19835e52018-01-30 13:35:54 +00003412 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003413 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3414 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3415 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3416 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003417 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003418 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3419 * The activation function is performed after the bias addition
3420 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003421 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3422 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3423 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3424 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3425 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003426 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3427 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3428 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3429 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3430 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3431 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003432 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003433 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3434 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3435 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3436 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3437 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003438 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3439 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3440 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3441 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3442 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3443 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003444 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003445 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003446 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003447 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003448 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003449 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003450 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3451 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003452 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003453 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003454 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003455 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003456__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
3457 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003458#if defined(BETA)
3459 IMAGE_DECLARATION(src2),
3460#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00003461 IMAGE_DECLARATION(dst),
3462 uint src0_stride_z,
3463 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003464#if defined(BETA)
3465 uint src2_stride_z,
3466#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003467 uint dst_stride_z
3468#if defined(REINTERPRET_OUTPUT_AS_3D)
3469 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003470 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003471#endif // REINTERPRET_OUTPUT_AS_3D
3472 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003473{
Gian Marco36a0a462018-01-12 10:21:40 +00003474 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3475 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003476 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003477
Gian Marco36a0a462018-01-12 10:21:40 +00003478 // Offset
3479 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3480 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003481
Gian Marco36a0a462018-01-12 10:21:40 +00003482 // src_addr_a = address of matrix A
3483 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003484 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3485 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3486
3487#if defined(MATRIX_B_DEPTH)
3488 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3489 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3490#else // defined(MATRIX_B_DEPTH)
3491 src1_addr_in_bytes += z * src1_stride_z;
3492#endif // defined(MATRIX_B_DEPTH)
3493
3494 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3495 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003496
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003497 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00003498 __global half *src_end_addr_b = src_addr_b + COLS_B;
3499
3500 src_addr_a += offset_row_a;
3501 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003502
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003503 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003504 half8 c0 = 0.0f;
3505 half8 c1 = 0.0f;
3506 half8 c2 = 0.0f;
3507 half8 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003508
Gian Marco36a0a462018-01-12 10:21:40 +00003509 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003510 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003511 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003512 half4 a0 = vload4(0, src_addr_a);
3513 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003514
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003515 c0 += (half8)a0.s0 * b0;
3516 c1 += (half8)a0.s1 * b0;
3517 c2 += (half8)a0.s2 * b0;
3518 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003519
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003520 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003521 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
3522 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003523
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003524 c0 += (half8)a0.s0 * b0;
3525 c1 += (half8)a0.s1 * b0;
3526 c2 += (half8)a0.s2 * b0;
3527 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003528 }
3529
Gian Marco36a0a462018-01-12 10:21:40 +00003530 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003531 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003532 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003533 half4 a0 = vload4(0, src_addr_a);
3534 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003535
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003536 c0 += (half8)a0.s0 * b0;
3537 c1 += (half8)a0.s1 * b0;
3538 c2 += (half8)a0.s2 * b0;
3539 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003540 }
3541
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003542 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003543 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3544
Gian Marcoae2af742018-02-15 12:35:44 +00003545 // Compute dst address
3546 __global uchar *dst_addr = offset(&dst, 0, 0);
3547
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003548 uint4 zout = 0;
3549
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003550#if defined(REINTERPRET_OUTPUT_AS_3D)
3551 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003552 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003553 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003554 // | |
3555 // | plane0 |
3556 // | |
3557 // |__________________|
3558 // |******************|
3559 // | cross_plane_pad |
3560 // |******************|
3561 // | |
3562 // | plane1 |
3563 // | |
3564 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003565
3566 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003567 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3568 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003569
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003570 // Add offset due to the cross plane paddings
3571 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003572
3573 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3574 // multiply dst_stride_z by DEPTH_GEMM3D
3575 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003576#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003577 // Add offset for batched GEMM
3578 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003579#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3580
3581 // Multiply by the weight of matrix-matrix product and store the result
3582#if defined(ALPHA)
3583 SCALE_BLOCK(4, half, c, ALPHA);
3584#endif // defined(ALPHA)
3585
3586 // Add beta*bias
3587#if defined(BETA)
3588 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3589
3590#if defined(BROADCAST_BIAS)
3591 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
3592
3593 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3594
3595#ifndef UNIT_BETA
3596 SCALE_BLOCK(1, half, bias, BETA);
3597#endif // UNIT_BIAS
3598
3599 // c = c + bias[broadcasted]
3600 ADD_BLOCK_BROADCAST(4, c, bias0);
3601
3602#else // defined(BROADCAST_BIAS)
3603
3604 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3605 2) * src2_stride_z;
3606
3607 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3608
3609#ifndef UNIT_BETA
3610 SCALE_BLOCK(4, half, bias, BETA);
3611#endif // UNIT_BIAS
3612
3613 // c = c + bias
3614 ADD_BLOCK(4, c, bias);
3615
3616#endif // defined(BROADCAST_BIAS)
3617#endif // defined(BETA)
3618
3619#if defined(ACTIVATION_TYPE)
3620 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
3621#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003622
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003623 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003624 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3625 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3626 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3627 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003628}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003629
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
3648 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3649 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3650 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3651 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3652 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3653 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3654 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3655 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3656 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3657 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003658 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3659 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3660 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3661 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3662 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3663 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003664 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3665 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3666 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3667 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3668 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3669 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3670 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3671 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003672 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003673 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3674 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3675 */
3676__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
3677 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003678#if defined(BETA)
3679 IMAGE_DECLARATION(src2),
3680#endif // defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003681 IMAGE_DECLARATION(dst),
3682 uint src0_stride_z,
3683 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003684#if defined(BETA)
3685 uint src2_stride_z,
3686#endif //defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003687 uint dst_stride_z
3688#if defined(REINTERPRET_OUTPUT_AS_3D)
3689 ,
3690 uint cross_plane_pad
3691#endif // REINTERPRET_OUTPUT_AS_3D
3692 )
3693{
    // x and y select the src1 and src0 row read by this work-item, z selects the batch
3694 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3695 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3696 int z = get_global_id(2);
3697
3698 // Offset (in elements) of this work-item's first vector inside the selected src0/src1 row
3699 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3700 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3701
3702 // src_addr_a = address of matrix A
3703 // src_addr_b = address of matrix B
3704 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3705 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3706
3707#if defined(MATRIX_B_DEPTH)
3708 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3709 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3710#else // defined(MATRIX_B_DEPTH)
3711 src1_addr_in_bytes += z * src1_stride_z;
3712#endif // defined(MATRIX_B_DEPTH)
3713
3714 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3715 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3716
3717 // Compute end row address for matrix B (taken before offset_row_b is applied; used as the loop bound below)
3718 __global half *src_end_addr_b = src_addr_b + COLS_B;
3719
3720 src_addr_a += offset_row_a;
3721 src_addr_b += offset_row_b;
3722
3723 // Reset accumulators: one float8 per output row, so the products are summed in float although the inputs are half
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003724 float8 c0 = 0.0f;
3725 float8 c1 = 0.0f;
3726 float8 c2 = 0.0f;
3727 float8 c3 = 0.0f;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003728
    // Main loop, unrolled by 2: each iteration consumes two 4-element vectors of A and two 8-element vectors of B
3729 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
3730 {
3731 // Load values from matrix A (interleaved) and matrix B (transposed)
3732 float4 a0 = convert_float4(vload4(0, src_addr_a));
3733 float8 b0 = convert_float8(vload8(0, src_addr_b));
3734
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003735 c0 += (float8)a0.s0 * b0;
3736 c1 += (float8)a0.s1 * b0;
3737 c2 += (float8)a0.s2 * b0;
3738 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003739
3740 // Load values from matrix A (interleaved) and matrix B (transposed)
3741 a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
3742 b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
3743
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003744 c0 += (float8)a0.s0 * b0;
3745 c1 += (float8)a0.s1 * b0;
3746 c2 += (float8)a0.s2 * b0;
3747 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003748 }
3749
    // Leftover loop: one A vector and one B vector per iteration until the end of the B row is reached
3750 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
3751 {
3752 // Load values from matrix A (interleaved) and matrix B (transposed)
3753 float4 a0 = convert_float4(vload4(0, src_addr_a));
3754 float8 b0 = convert_float8(vload8(0, src_addr_b));
3755
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003756 c0 += (float8)a0.s0 * b0;
3757 c1 += (float8)a0.s1 * b0;
3758 c2 += (float8)a0.s2 * b0;
3759 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003760 }
3761
3762 // Compute destination address
3763 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3764
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003765 // Compute dst address
3766 __global uchar *dst_addr = offset(&dst, 0, 0);
3767
    // Extra byte offset added to each of the four output rows; stays 0 unless REINTERPRET_OUTPUT_AS_3D is defined
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003768 uint4 zout = 0;
3769
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003770#if defined(REINTERPRET_OUTPUT_AS_3D)
3771 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
3772 // in order to take into account the presence of possible cross plane paddings
3773 //
3774 // | |
3775 // | plane0 |
3776 // | |
3777 // |__________________|
3778 // |******************|
3779 // | cross_plane_pad |
3780 // |******************|
3781 // | |
3782 // | plane1 |
3783 // | |
3784 // |__________________|
3785
3786 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003787 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3788 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003789
3790 // Add offset due to the cross plane paddings
3791 zout *= (cross_plane_pad * dst_stride_y);
3792
3793 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3794 // multiply dst_stride_z by DEPTH_GEMM3D
3795 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003796#else // defined(REINTERPRET_OUTPUT_AS_3D)
3797 // Add offset for batched GEMM
3798 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003799#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3800
3801 // Multiply by the weight of matrix-matrix product and store the result
3802#if defined(ALPHA)
3803 SCALE_BLOCK(4, float, c, ALPHA);
3804#endif // defined(ALPHA)
3805
    // Add beta*bias. The bias is loaded as half and converted to float, so scaling and addition are done in float
3806#if defined(BETA)
3807 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3808
3809#if defined(BROADCAST_BIAS)
    // Only the x position addresses the bias: a single bias row is loaded and added to all four output rows
3810 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
3811
3812 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3813
3814 float8 bias_f0 = convert_float8(bias0);
3815
3816#ifndef UNIT_BETA
3817 SCALE_BLOCK(1, float, bias_f, BETA);
3818#endif // UNIT_BETA
3819
3820 // c = c + bias[broadcasted]
3821 ADD_BLOCK_BROADCAST(4, c, bias_f0);
3822
3823#else // defined(BROADCAST_BIAS)
    // The bias has its own 4x8 tile, addressed with the same x/y/z work-item ids as the output tile
3824 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3825 2) * src2_stride_z;
3826
3827 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3828
3829 float8 bias_f0 = convert_float8(bias0);
3830 float8 bias_f1 = convert_float8(bias1);
3831 float8 bias_f2 = convert_float8(bias2);
3832 float8 bias_f3 = convert_float8(bias3);
3833
3834#ifndef UNIT_BETA
3835 SCALE_BLOCK(4, float, bias_f, BETA);
3836#endif // UNIT_BETA
3837
3838 // c = c + bias
3839 ADD_BLOCK(4, c, bias_f);
3840
3841#endif // defined(BROADCAST_BIAS)
3842#endif // defined(BETA)
3843
    // Convert the float accumulators to half; the optional activation and the store operate on the half values
3844 half8 c_h0 = convert_half8(c0);
3845 half8 c_h1 = convert_half8(c1);
3846 half8 c_h2 = convert_half8(c2);
3847 half8 c_h3 = convert_half8(c3);
3848
3849#if defined(ACTIVATION_TYPE)
3850 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
3851#endif // defined(ACTIVATION_TYPE)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003852
3853 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003854 vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3855 vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3856 vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3857 vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003858}
3859
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003860/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003861 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003862 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003863 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3864 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3865 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3866 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003867 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003868 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3869 * The activation function is performed after the bias addition
3870 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003871 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3872 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3873 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3874 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3875 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003876 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3877 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3878 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3879 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3880 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3881 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3882 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3883 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3884 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3885 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3886 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3887 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003888 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3889 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3890 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3891 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3892 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3893 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003894 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3895 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3896 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3897 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3898 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3899 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003900 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3901 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3902 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003903 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003904 */
3905__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
3906 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003907#if defined(BETA)
3908 IMAGE_DECLARATION(src2),
3909#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003910 IMAGE_DECLARATION(dst),
3911 uint src0_stride_z,
3912 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003913#if defined(BETA)
3914 uint src2_stride_z,
3915#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003916 uint dst_stride_z
3917#if defined(REINTERPRET_OUTPUT_AS_3D)
3918 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003919 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003920#endif // REINTERPRET_OUTPUT_AS_3D
3921 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003922{
    // x and y select the src1 and src0 row read by this work-item, z selects the batch
3923 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3924 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3925 int z = get_global_id(2);
3926
3927 // Offset (in elements) of this work-item's first vector inside the selected src0/src1 row
3928 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3929 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3930
3931 // src_addr_a = address of matrix A
3932 // src_addr_b = address of matrix B
3933 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3934 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3935
3936#if defined(MATRIX_B_DEPTH)
3937 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3938 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3939#else // defined(MATRIX_B_DEPTH)
3940 src1_addr_in_bytes += z * src1_stride_z;
3941#endif // defined(MATRIX_B_DEPTH)
3942
3943 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3944 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3945
3946 // Compute end row address for matrix B
    // NOTE(review): src_end_addr_b is not read again in this kernel (the loops below count iterations with i) - confirm it can be dropped
3947 __global half *src_end_addr_b = src_addr_b + COLS_B;
3948
3949 src_addr_a += offset_row_a;
3950 src_addr_b += offset_row_b;
3951
3952 // Reset accumulators: one half8 per output row, so this kernel accumulates in half precision
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003953 half8 c0 = 0.0f;
3954 half8 c1 = 0.0f;
3955 half8 c2 = 0.0f;
3956 half8 c3 = 0.0f;
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003957
    // Number of accumulation steps: COLS_B divided by the per-step advance of src_addr_b (8 * MULT_TRANSPOSE1XW_WIDTH elements)
3958#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
3959
    // Main loop, unrolled by 4 accumulation steps
3960 int i = 0;
3961 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
3962 {
3963#if MULT_INTERLEAVE4X4_HEIGHT == 1
    // In this branch a single vload8 of A feeds two consecutive steps (lanes s0-s3, then s4-s7)
3964 // Load values from matrix A (interleaved) and matrix B (transposed)
3965 half8 a0 = vload8(0, src_addr_a);
3966 half8 b0 = vload8(0, src_addr_b);
3967
3968 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3969 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3970
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003971 c0 = fma((half8)a0.s0, b0, c0);
3972 c1 = fma((half8)a0.s1, b0, c1);
3973 c2 = fma((half8)a0.s2, b0, c2);
3974 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003975
3976 // Load values from matrix B (transposed)
3977 b0 = vload8(0, src_addr_b);
3978
3979 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3980
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003981 c0 = fma((half8)a0.s4, b0, c0);
3982 c1 = fma((half8)a0.s5, b0, c1);
3983 c2 = fma((half8)a0.s6, b0, c2);
3984 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003985
3986 // Load values from matrix A (interleaved) and matrix B (transposed)
3987 a0 = vload8(0, src_addr_a);
3988 b0 = vload8(0, src_addr_b);
3989
3990 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3991 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3992
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003993 c0 = fma((half8)a0.s0, b0, c0);
3994 c1 = fma((half8)a0.s1, b0, c1);
3995 c2 = fma((half8)a0.s2, b0, c2);
3996 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003997
3998 // Load values from matrix B (transposed)
3999 b0 = vload8(0, src_addr_b);
4000
4001 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4002
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004003 c0 = fma((half8)a0.s4, b0, c0);
4004 c1 = fma((half8)a0.s5, b0, c1);
4005 c2 = fma((half8)a0.s6, b0, c2);
4006 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004007#else // MULT_INTERLEAVE4X4_HEIGHT == 1
    // General case: each of the four steps loads its own 4-element vector of A
4008 // Load values from matrix A (interleaved) and matrix B (transposed)
4009 half4 a0 = vload4(0, src_addr_a);
4010 half8 b0 = vload8(0, src_addr_b);
4011
4012 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4013 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4014
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004015 c0 = fma((half8)a0.s0, b0, c0);
4016 c1 = fma((half8)a0.s1, b0, c1);
4017 c2 = fma((half8)a0.s2, b0, c2);
4018 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004019
4020 // Load values from matrix A (interleaved) and matrix B (transposed)
4021 a0 = vload4(0, src_addr_a);
4022 b0 = vload8(0, src_addr_b);
4023
4024 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4025 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4026
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004027 c0 = fma((half8)a0.s0, b0, c0);
4028 c1 = fma((half8)a0.s1, b0, c1);
4029 c2 = fma((half8)a0.s2, b0, c2);
4030 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004031
4032 // Load values from matrix A (interleaved) and matrix B (transposed)
4033 a0 = vload4(0, src_addr_a);
4034 b0 = vload8(0, src_addr_b);
4035
4036 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4037 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4038
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004039 c0 = fma((half8)a0.s0, b0, c0);
4040 c1 = fma((half8)a0.s1, b0, c1);
4041 c2 = fma((half8)a0.s2, b0, c2);
4042 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004043
4044 // Load values from matrix A (interleaved) and matrix B (transposed)
4045 a0 = vload4(0, src_addr_a);
4046 b0 = vload8(0, src_addr_b);
4047
4048 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4049 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4050
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004051 c0 = fma((half8)a0.s0, b0, c0);
4052 c1 = fma((half8)a0.s1, b0, c1);
4053 c2 = fma((half8)a0.s2, b0, c2);
4054 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004055#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
4056 }
4057
    // Leftover loop: remaining steps (fewer than 4), one A vector and one B vector per iteration
4058 for(; i < (int)(COLS_MTX_B); ++i)
4059 {
4060 // Load values from matrix A (interleaved) and matrix B (transposed)
4061 half4 a0 = vload4(0, src_addr_a);
4062 half8 b0 = vload8(0, src_addr_b);
4063
4064 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4065 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4066
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004067 c0 = fma((half8)a0.s0, b0, c0);
4068 c1 = fma((half8)a0.s1, b0, c1);
4069 c2 = fma((half8)a0.s2, b0, c2);
4070 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004071 }
4072
4073 // Compute destination address
4074 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4075
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004076 // Compute dst address
4077 __global uchar *dst_addr = offset(&dst, 0, 0);
4078
    // Extra byte offset added to each of the four output rows; stays 0 unless REINTERPRET_OUTPUT_AS_3D is defined
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004079 uint4 zout = 0;
4080
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004081#if defined(REINTERPRET_OUTPUT_AS_3D)
4082 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004083 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004084 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004085 // | |
4086 // | plane0 |
4087 // | |
4088 // |__________________|
4089 // |******************|
4090 // | cross_plane_pad |
4091 // |******************|
4092 // | |
4093 // | plane1 |
4094 // | |
4095 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004096
4097 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004098 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4099 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004100
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004101 // Add offset due to the cross plane paddings
4102 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004103
4104 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4105 // multiply dst_stride_z by DEPTH_GEMM3D
4106 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004107#else // defined(REINTERPRET_OUTPUT_AS_3D)
4108 // Add offset for batched GEMM
4109 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004110#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4111
4112 // Multiply by the weight of matrix-matrix product and store the result
4113#if defined(ALPHA)
4114 SCALE_BLOCK(4, half, c, ALPHA);
4115#endif // defined(ALPHA)
4116
4117 // Add beta*bias (scaling and addition are done in half precision)
4118#if defined(BETA)
4119 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4120
4121#if defined(BROADCAST_BIAS)
    // Only the x position addresses the bias: a single bias row is loaded and added to all four output rows
4122 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4123
4124 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4125
4126#ifndef UNIT_BETA
4127 SCALE_BLOCK(1, half, bias, BETA);
4128#endif // UNIT_BETA
4129
4130 // c = c + bias[broadcasted]
4131 ADD_BLOCK_BROADCAST(4, c, bias0);
4132
4133#else // defined(BROADCAST_BIAS)
    // The bias has its own 4x8 tile, addressed with the same x/y/z work-item ids as the output tile
4134 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4135 2) * src2_stride_z;
4136
4137 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4138
4139#ifndef UNIT_BETA
4140 SCALE_BLOCK(4, half, bias, BETA);
4141#endif // UNIT_BETA
4142
4143 // c = c + bias
4144 ADD_BLOCK(4, c, bias);
4145
4146#endif // defined(BROADCAST_BIAS)
4147#endif // defined(BETA)
4148
4149#if defined(ACTIVATION_TYPE)
4150 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4151#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004152
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004153 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004154 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4155 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4156 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4157 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004158}
Georgios Pinitas84225582018-05-14 12:00:05 +01004159
4160// Undefine local defines
4161#undef COLS_MTX_B
4162
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004163#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004164
Gian Marco36a0a462018-01-12 10:21:40 +00004165#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004166
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004167#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
4168#if defined(DATA_TYPE)
4169#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 *       Only four row accumulators (acc0..acc3) are declared, so NUM_ELEMS_PROCESSED_PER_THREAD_Y is expected to be in the range [1, 4]
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) is part of the kernel signature only if -DBETA is passed at compile time. The bias is scaled by BETA unless -DUNIT_BETA is defined.
 *       If -DBROADCAST_BIAS is defined, a single bias row is loaded and added to every row of the output block.
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif // defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // src_addr.s0 is the byte offset into matrix A, src_addr.s1 the byte offset into matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    // One plane index is computed for each of the (up to four) rows handled by this work-item
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp the plane index to the last plane (DEPTH_GEMM3D - 1)
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator per output row handled by this work-item
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes two columns of matrix A and two rows of matrix B
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        // NOTE(review): LOAD_BLOCK presumably declares a0..a(N-1) and appends the row index to "zin.s" (zin.s0, zin.s1, ...) - confirm in gemm_helpers.h
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (two consecutive rows)
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Leftover loop: handles the remaining column of matrix A (one at a time) when COLS_A is not a multiple of 2
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row byte offset due to cross plane paddings; stays zero unless REINTERPRET_OUTPUT_AS_3D is defined
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp the plane index to the last plane (DEPTH_GEMM3D - 1)
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // The bias is a plain 2D tile: zero0..zero(N-1) are passed to LOAD_BLOCK as the cross-plane offsets
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Only the x position is used: a single bias row is shared by all the rows of the output block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004480#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004481
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
4532__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4533 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004534#if defined(BETA)
4535 IMAGE_DECLARATION(src2),
4536#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004537 IMAGE_DECLARATION(dst),
4538 uint src0_stride_z,
4539 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004540#if defined(BETA)
4541 uint src2_stride_z,
4542#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004543 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004544#if defined(REINTERPRET_INPUT_AS_3D)
4545 ,
4546 uint src_cross_plane_pad
4547#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004548#if defined(REINTERPRET_OUTPUT_AS_3D)
4549 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004550 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004551#endif // REINTERPRET_OUTPUT_AS_3D
4552 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004553{
4554 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4555
4556 // Compute starting address for matrix A and matrix B
4557 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4558
4559 // Update address for matrix A
4560 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4561
4562 // Update address for matrix B
4563 src_addr.s1 += idx * sizeof(float);
4564
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004565#if defined(REINTERPRET_INPUT_AS_3D)
4566 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4567 // in order to take into account the presence of possible cross plane paddings
4568 //
4569 // | |
4570 // | plane0 |
4571 // | |
4572 // |__________________|
4573 // |******************|
4574 // | cross_plane_pad |
4575 // |******************|
4576 // | |
4577 // | plane1 |
4578 // | |
4579 // |__________________|
4580
4581 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4582 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4583 zin = min(DEPTH_GEMM3D - 1, zin);
4584
4585 // Add offset due to the cross plane paddings
4586 zin *= (src_cross_plane_pad * src0_stride_y);
4587
4588 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4589 // multiply src0_stride_z by DEPTH_GEMM3D
4590 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4591
4592#else // defined(REINTERPRET_INPUT_AS_3D)
4593
Gian Marcoae2af742018-02-15 12:35:44 +00004594 // Add offset for batched GEMM
4595 src_addr.s0 += get_global_id(2) * src0_stride_z;
4596
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004597#endif // defined(REINTERPRET_INPUT_AS_3D)
4598
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004599#if defined(MATRIX_B_DEPTH)
4600 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4601 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4602#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004603 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004604#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004605
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004606 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004607 float4 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004608
4609#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004610 float4 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004611#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4612
4613#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004614 float4 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004615#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4616
4617#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004618 float4 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004619#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4620
4621 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004622 int i = 0;
4623 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004624 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004625#if defined(REINTERPRET_INPUT_AS_3D)
4626 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004627 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4628#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004629 // Load values from matrix A and matrix B
4630 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004631#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004632 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004633#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4634#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004635 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004636#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4637#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004638 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004639#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004640#endif // defined(REINTERPRET_INPUT_AS_3D)
4641
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004642 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4643 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004644
4645 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004646 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
4647 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
4648 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
4649 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004650
4651#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004652
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004653 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
4654 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
4655 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
4656 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004657
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004658#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4659#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004660
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004661 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
4662 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
4663 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
4664 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004665
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004666#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4667#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004668
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004669 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
4670 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
4671 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
4672 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004673#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004674
4675 // Load values from matrix A and matrix B
4676 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4677 src_addr.s1 += src1_stride_y;
4678
4679 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004680 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
4681 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
4682 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
4683 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004684
4685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4686
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004687 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
4688 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
4689 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
4690 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004691
4692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4693#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4694
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004695 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
4696 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
4697 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
4698 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004699
4700#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4702
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004703 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
4704 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
4705 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
4706 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004707#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4708
4709 // Load values from matrix A and matrix B
4710 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4711 src_addr.s1 += src1_stride_y;
4712
4713 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004714 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
4715 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
4716 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
4717 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004718
4719#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4720
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004721 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
4722 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
4723 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
4724 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004725
4726#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4727#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4728
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004729 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
4730 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
4731 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
4732 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004733
4734#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4735#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4736
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004737 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
4738 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
4739 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
4740 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004741#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4742
4743 // Load values from matrix A and matrix B
4744 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4745 src_addr.s1 += src1_stride_y;
4746
4747 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004748 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
4749 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
4750 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
4751 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004752
4753#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4754
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004755 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
4756 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
4757 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
4758 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004759
4760#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4761#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4762
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004763 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
4764 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
4765 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
4766 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004767
4768#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4769#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4770
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004771 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
4772 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
4773 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
4774 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004775#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4776
4777 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004778 }
4779
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004780 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004781 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004782#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004783 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004784 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4785#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4786 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4788#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4789 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4790#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4791#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4792 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4793#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4794#else // defined(REINTERPRET_INPUT_AS_3D)
4795 // Load values from matrix A
4796 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004797#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4798 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4799#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4800#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4801 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4802#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4803#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4804 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4805#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004806#endif // defined(REINTERPRET_INPUT_AS_3D)
4807
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004808 // Load values from matrix B
4809 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004810 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004811
4812 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004813 acc0.s0 = fma(a0, b0.s0, acc0.s0);
4814 acc0.s1 = fma(a0, b0.s1, acc0.s1);
4815 acc0.s2 = fma(a0, b0.s2, acc0.s2);
4816 acc0.s3 = fma(a0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004817#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004818 acc1.s0 = fma(a1, b0.s0, acc1.s0);
4819 acc1.s1 = fma(a1, b0.s1, acc1.s1);
4820 acc1.s2 = fma(a1, b0.s2, acc1.s2);
4821 acc1.s3 = fma(a1, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004822#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4823#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004824 acc2.s0 = fma(a2, b0.s0, acc2.s0);
4825 acc2.s1 = fma(a2, b0.s1, acc2.s1);
4826 acc2.s2 = fma(a2, b0.s2, acc2.s2);
4827 acc2.s3 = fma(a2, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004828#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4829#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004830 acc3.s0 = fma(a3, b0.s0, acc3.s0);
4831 acc3.s1 = fma(a3, b0.s1, acc3.s1);
4832 acc3.s2 = fma(a3, b0.s2, acc3.s2);
4833 acc3.s3 = fma(a3, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004834#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004835
4836 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004837 }
4838
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004839 int z = get_global_id(2);
4840
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004841 // Compute destination address
4842 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4843
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004844 // Compute dst address
4845 __global uchar *dst_addr = offset(&dst, 0, 0);
4846
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004847 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004848
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004849#if defined(REINTERPRET_OUTPUT_AS_3D)
4850 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004851 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004852 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004853 // | |
4854 // | plane0 |
4855 // | |
4856 // |__________________|
4857 // |******************|
4858 // | cross_plane_pad |
4859 // |******************|
4860 // | |
4861 // | plane1 |
4862 // | |
4863 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004864
4865 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004866 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4867 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004868
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004869 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004870 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004871
4872 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4873 // multiply dst_stride_z by DEPTH_GEMM3D
4874 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004875#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004876 // Add offset for batched GEMM
4877 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004878#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4879
4880 // Multiply by the weight of matrix-matrix product and store the result
4881#if defined(ALPHA)
4882 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
4883#endif // defined(ALPHA)
4884
4885 // Add beta*bias
4886#if defined(BETA)
4887 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
4888
4889#if defined(BROADCAST_BIAS)
4890 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
4891
4892 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4893
4894#ifndef UNIT_BETA
4895 SCALE_BLOCK(1, float, bias, BETA);
4896#endif // UNIT_BIAS
4897
4898 // acc = acc + bias[broadcasted]
4899 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
4900
4901#else // defined(BROADCAST_BIAS)
4902 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
4903 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
4904
4905 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4906
4907#ifndef UNIT_BETA
4908 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
4909#endif // UNIT_BIAS
4910
4911 // acc = acc + bias
4912 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
4913
4914#endif // defined(BROADCAST_BIAS)
4915#endif // defined(BETA)
4916
4917#if defined(ACTIVATION_TYPE)
4918 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
4919#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004920
4921 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004922 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004923#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004924 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004925#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4926#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004927 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004928#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4929#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004930 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004931#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004932}
4933
4934/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4935 *
4936 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4937 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
4938 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4939 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
4940 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4941 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004942 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4943 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004944 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004945 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
4946 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004947 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4948 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004949 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4950 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4951 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4952 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
4953 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004954 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004955 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4956 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
4957 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4958 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
4959 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4960 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4961 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4962 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
4963 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4964 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
4965 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004966 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4967 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4968 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4969 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4970 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4971 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004972 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4973 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4974 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4975 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4976 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4977 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004978 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4979 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004980 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004981 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004982 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4983 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004984 */
4985__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
4986 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004987#if defined(BETA)
4988 IMAGE_DECLARATION(src2),
4989#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004990 IMAGE_DECLARATION(dst),
4991 uint src0_stride_z,
4992 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004993#if defined(BETA)
4994 uint src2_stride_z,
4995#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004996 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004997#if defined(REINTERPRET_INPUT_AS_3D)
4998 ,
4999 uint src_cross_plane_pad
5000#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005001#if defined(REINTERPRET_OUTPUT_AS_3D)
5002 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005003 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005004#endif // REINTERPRET_OUTPUT_AS_3D
5005 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005006{
5007 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5008 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5009
5010 // Compute starting address for matrix A and Matrix B
5011 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5012
5013 // Update address for the matrix A
5014 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5015
5016 // Update address for the matrix B
5017 src_addr.s1 += idx * sizeof(float);
5018
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005019#if defined(REINTERPRET_INPUT_AS_3D)
5020 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5021 // in order to take into account the presence of possible cross plane paddings
5022 //
5023 // | |
5024 // | plane0 |
5025 // | |
5026 // |__________________|
5027 // |******************|
5028 // | cross_plane_pad |
5029 // |******************|
5030 // | |
5031 // | plane1 |
5032 // | |
5033 // |__________________|
5034
5035 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5036 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5037 zin = min(DEPTH_GEMM3D - 1, zin);
5038
5039 // Add offset due to the cross plane paddings
5040 zin *= (src_cross_plane_pad * src0_stride_y);
5041
5042 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5043 // multiply src0_stride_z by DEPTH_GEMM3D
5044 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5045
5046#else // defined(REINTERPRET_INPUT_AS_3D)
5047
Gian Marcoae2af742018-02-15 12:35:44 +00005048 // Add offset for batched GEMM
5049 src_addr.s0 += get_global_id(2) * src0_stride_z;
5050
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005051#endif // defined(REINTERPRET_INPUT_AS_3D)
5052
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005053#if defined(MATRIX_B_DEPTH)
5054 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5055 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5056#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005057 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005058#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005059
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005060 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005061 float2 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005062#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005063 float2 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005064#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5065#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005066 float2 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005067#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5068#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005069 float2 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005070#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5071
5072 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005073 int i = 0;
5074 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005075 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005076#if defined(REINTERPRET_INPUT_AS_3D)
5077 // Load values from matrix A
5078 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
5079#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005080 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005081 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005082#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005083
5084 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005085 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5086 src_addr.s1 += src1_stride_y;
5087 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5088 src_addr.s1 += src1_stride_y;
5089 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5090 src_addr.s1 += src1_stride_y;
5091 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5092 src_addr.s1 += src1_stride_y;
5093 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5094 src_addr.s1 += src1_stride_y;
5095 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5096 src_addr.s1 += src1_stride_y;
5097 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5098 src_addr.s1 += src1_stride_y;
5099 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5100 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005101
5102 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005103 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5104 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
5105 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
5106 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
5107 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
5108 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
5109 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
5110 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005111
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005112 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5113 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
5114 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
5115 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
5116 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
5117 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
5118 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
5119 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005120
5121#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005122#if defined(REINTERPRET_INPUT_AS_3D)
5123 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5124#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005125 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005126#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005127 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
5128 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
5129 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
5130 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
5131 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
5132 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
5133 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
5134 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005135
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005136 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
5137 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
5138 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
5139 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
5140 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
5141 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
5142 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
5143 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005144#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5145#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005146#if defined(REINTERPRET_INPUT_AS_3D)
5147 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5148#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005149 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005150#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005151 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
5152 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
5153 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
5154 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
5155 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
5156 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
5157 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
5158 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005159
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005160 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
5161 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
5162 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
5163 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
5164 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
5165 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
5166 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
5167 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005168#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5169#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005170#if defined(REINTERPRET_INPUT_AS_3D)
5171 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5172#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005173 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005174#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005175 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
5176 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
5177 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
5178 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
5179 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
5180 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
5181 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
5182 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005183
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005184 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
5185 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
5186 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
5187 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
5188 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
5189 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
5190 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
5191 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005192#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005193
5194 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005195 }
5196 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005197 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005198 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005199#if defined(REINTERPRET_INPUT_AS_3D)
5200 // Load values from matrix A
5201 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5202#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5203 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5204#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5205#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5206 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5207#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5208#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5209 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5210#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5211#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005212 // Load values from matrix A
5213 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5214#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5215 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5216#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5217#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5218 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5219#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5220#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5221 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5222#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005223#endif // defined(REINTERPRET_INPUT_AS_3D)
5224
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005225 // Load values from matrix B
5226 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005227 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005228
5229 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005230 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5231 acc0.s1 = fma(a0, b0.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005232#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005233 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5234 acc1.s1 = fma(a1, b0.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005235#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005237 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5238 acc2.s1 = fma(a2, b0.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005239#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5240#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005241 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5242 acc3.s1 = fma(a3, b0.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005243#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005244
5245 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005246 }
5247
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005248 int z = get_global_id(2);
5249
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005250 // Compute destination address
5251 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5252
Gian Marcoae2af742018-02-15 12:35:44 +00005253 // Compute dst address
5254 __global uchar *dst_addr = offset(&dst, 0, 0);
5255
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005256 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005257
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005258#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005259
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005260 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005261 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005262 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005263 // | |
5264 // | plane0 |
5265 // | |
5266 // |__________________|
5267 // |******************|
5268 // | cross_plane_pad |
5269 // |******************|
5270 // | |
5271 // | plane1 |
5272 // | |
5273 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00005274
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005275 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005276 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5277 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005278
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005279 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005280 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005281
5282 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5283 // multiply dst_stride_z by DEPTH_GEMM3D
5284 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005285#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005286 // Add offset for batched GEMM
5287 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005288#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5289
5290 // Multiply by the weight of matrix-matrix product and store the result
5291#if defined(ALPHA)
5292 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5293#endif // defined(ALPHA)
5294
5295 // Add beta*bias
5296#if defined(BETA)
5297 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5298
5299#if defined(BROADCAST_BIAS)
5300 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
5301
5302 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5303
5304#ifndef UNIT_BETA
5305 SCALE_BLOCK(1, float, bias, BETA);
5306#endif // UNIT_BIAS
5307
5308 // acc = acc + bias[broadcasted]
5309 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5310
5311#else // defined(BROADCAST_BIAS)
5312 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
5313 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5314
5315 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5316
5317#ifndef UNIT_BETA
5318 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
5319#endif // UNIT_BIAS
5320
5321 // acc = acc + bias
5322 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5323
5324#endif // defined(BROADCAST_BIAS)
5325#endif // defined(BETA)
5326
5327#if defined(ACTIVATION_TYPE)
5328 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
5329#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005330
5331 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005332 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005333#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005334 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005335#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5336#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005337 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005338#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5339#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005340 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005341#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005342}
5343
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005344#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
 *       The accumulators are converted back to half just before the (optional) activation and the store.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads 8 values of matrix B per row, stores 8 output values per row and uses a hard-coded factor of 8 for the bias offset along X,
 *       so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 is expected. Only the accumulators acc0..acc3 exist, so -DNUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4].
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note The optional value of scalar beta is passed at compile time using -DBETA=beta. The bias (src2) arguments are part of the kernel signature only if BETA is defined.
 *       If -DUNIT_BETA is passed, the bias is not multiplied by BETA. If -DBROADCAST_BIAS is passed, a single bias row is loaded and added to every output row.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif // defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // First column of matrix B (in elements) handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (.s0 = byte offset in A, .s1 = byte offset in B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // 32-bit accumulators, one float8 per output row computed by this work-item
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of matrix A (and the 4 matching rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (the per-row cross plane offsets in zin are passed to LOAD_BLOCK)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (widened to float before accumulating)
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate: column 0 of the A tile times the current row of B
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Column 1 of the A tile times the next row of B
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Column 2 of the A tile times the next row of B
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Column 3 of the A tile times the next row of B
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: process the remaining (COLS_A % 4) columns of matrix A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        // Advance matrix A by one column and matrix B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row cross plane offsets for the store; they stay 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // Zero offsets passed to LOAD_BLOCK for the bias rows
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded; only the X offset of this work-item is applied
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // Widen the bias to float so that it is added to the 32-bit accumulators
    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);

#else // defined(BROADCAST_BIAS)
    // One bias row per output row: apply the X, Y and batch (Z) offsets of this work-item
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // Widen the bias rows to float so that they are added to the 32-bit accumulators
    float8 bias_f0 = convert_float8(bias0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 bias_f1 = convert_float8(bias1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 bias_f2 = convert_float8(bias2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 bias_f3 = convert_float8(bias3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Convert the 32-bit accumulators back to half; activation and store operate on the half values
    half8 acc_h0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc_h1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc_h2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc_h3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
}
5713
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5715 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005716 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5717 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5718 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5719 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5720 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005721 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5722 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005723 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005724 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5725 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005726 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5727 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005728 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5729 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5730 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5731 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5732 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005733 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5734 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5735 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5736 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5737 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5738 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5739 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5740 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5741 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5742 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5743 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5744 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5746 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5747 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5748 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5749 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5750 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005751 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5752 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
5756 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005757 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5758 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005759 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005760 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005761 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5762 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005763 */
5764__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5765 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005766#if defined(BETA)
5767 IMAGE_DECLARATION(src2),
5768#endif // defined(BETA)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005769 IMAGE_DECLARATION(dst),
5770 uint src0_stride_z,
5771 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005772#if defined(BETA)
5773 uint src2_stride_z,
5774#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005775 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005776#if defined(REINTERPRET_INPUT_AS_3D)
5777 ,
5778 uint src_cross_plane_pad
5779#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005780#if defined(REINTERPRET_OUTPUT_AS_3D)
5781 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005782 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005783#endif // REINTERPRET_OUTPUT_AS_3D
5784 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005785{
5786 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5787
5788 // Compute starting address for matrix A and Matrix B
5789 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5790
5791 // Update address for the matrix A
5792 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5793
5794 // Update address for the matrix B
5795 src_addr.s1 += idx * sizeof(half);
5796
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005797#if defined(REINTERPRET_INPUT_AS_3D)
5798 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5799 // in order to take into account the presence of possible cross plane paddings
5800 //
5801 // | |
5802 // | plane0 |
5803 // | |
5804 // |__________________|
5805 // |******************|
5806 // | cross_plane_pad |
5807 // |******************|
5808 // | |
5809 // | plane1 |
5810 // | |
5811 // |__________________|
5812
5813 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5814 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5815 zin = min(DEPTH_GEMM3D - 1, zin);
5816
5817 // Add offset due to the cross plane paddings
5818 zin *= (src_cross_plane_pad * src0_stride_y);
5819
5820 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5821 // multiply src0_stride_z by DEPTH_GEMM3D
5822 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5823
5824#else // defined(REINTERPRET_INPUT_AS_3D)
5825
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005826 // Add offset for batched GEMM
5827 src_addr.s0 += get_global_id(2) * src0_stride_z;
5828
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005829#endif // defined(REINTERPRET_INPUT_AS_3D)
5830
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005831#if defined(MATRIX_B_DEPTH)
5832 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5833 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5834#else // defined(MATRIX_B_DEPTH)
5835 src_addr.s1 += get_global_id(2) * src1_stride_z;
5836#endif // defined(MATRIX_B_DEPTH)
5837
5838 half8 acc0 = 0.0h;
5839#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5840 half8 acc1 = 0.0h;
5841#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5842#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5843 half8 acc2 = 0.0h;
5844#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5845#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5846 half8 acc3 = 0.0h;
5847#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5848
5849 int i = 0;
5850 for(; i <= ((int)COLS_A - 4); i += 4)
5851 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005852#if defined(REINTERPRET_INPUT_AS_3D)
5853 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005854 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5855#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005856 // Load values from matrix A
5857 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5858#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5859 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5860#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5861#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5862 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5863#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5864#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5865 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5866#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005867#endif // defined(REINTERPRET_INPUT_AS_3D)
5868
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005869 // Load values from matrix B
5870 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5871 src_addr.s1 += src1_stride_y;
5872
5873 // Accumulate
5874 acc0 = fma(b0, (half8)a0.s0, acc0);
5875#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5876 acc1 = fma(b0, (half8)a1.s0, acc1);
5877#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5878#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5879 acc2 = fma(b0, (half8)a2.s0, acc2);
5880#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5881#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5882 acc3 = fma(b0, (half8)a3.s0, acc3);
5883#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5884
5885 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5886 src_addr.s1 += src1_stride_y;
5887 acc0 = fma(b0, (half8)a0.s1, acc0);
5888#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5889 acc1 = fma(b0, (half8)a1.s1, acc1);
5890#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5891#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5892 acc2 = fma(b0, (half8)a2.s1, acc2);
5893#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5894#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5895 acc3 = fma(b0, (half8)a3.s1, acc3);
5896#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5897
5898 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5899 src_addr.s1 += src1_stride_y;
5900 acc0 = fma(b0, (half8)a0.s2, acc0);
5901#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5902 acc1 = fma(b0, (half8)a1.s2, acc1);
5903#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5904#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5905 acc2 = fma(b0, (half8)a2.s2, acc2);
5906#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5907#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5908 acc3 = fma(b0, (half8)a3.s2, acc3);
5909#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5910
5911 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5912 src_addr.s1 += src1_stride_y;
5913 acc0 = fma(b0, (half8)a0.s3, acc0);
5914#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5915 acc1 = fma(b0, (half8)a1.s3, acc1);
5916#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5917#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5918 acc2 = fma(b0, (half8)a2.s3, acc2);
5919#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5920#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5921 acc3 = fma(b0, (half8)a3.s3, acc3);
5922#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5923
5924 src_addr.s0 += 4 * sizeof(half);
5925 }
5926
5927 for(; i < (int)COLS_A; ++i)
5928 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005929#if defined(REINTERPRET_INPUT_AS_3D)
5930 // Load values from matrix A
5931 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5932#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5933 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5934#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5935#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5936 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5937#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5938#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5939 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5940#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5941#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005942 // Load values from matrix A
5943 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5944#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5945 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5946#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5947#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5948 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5949#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5950#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5951 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5952#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005953#endif // defined(REINTERPRET_INPUT_AS_3D)
5954
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005955 // Load values from matrix B
5956 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5957
5958 src_addr += (int2)(sizeof(half), src1_stride_y);
5959
5960 // Accumulate
5961 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
5962#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5963 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
5964#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5965#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5966 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
5967#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5968#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5969 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
5970#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5971 }
5972
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005973 int z = get_global_id(2);
5974
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005975 // Compute destination address
5976 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5977
5978 // Compute dst address
5979 __global uchar *dst_addr = offset(&dst, 0, 0);
5980
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005981 uint4 zout = 0;
5982
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005983#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005984
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005985 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005986 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005987 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005988 // | |
5989 // | plane0 |
5990 // | |
5991 // |__________________|
5992 // |******************|
5993 // | cross_plane_pad |
5994 // |******************|
5995 // | |
5996 // | plane1 |
5997 // | |
5998 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005999
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006000 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006001 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6002 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006003
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006004 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006005 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006006
6007 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6008 // multiply dst_stride_z by DEPTH_GEMM3D
6009 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006010#else // defined(REINTERPRET_OUTPUT_AS_3D)
6011 // Add offset for batched GEMM
6012 dst_addr += z * dst_stride_z;
6013#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6014
6015 // Multiply by the weight of matrix-matrix product and store the result
6016#if defined(ALPHA)
6017 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
6018#endif // defined(ALPHA)
6019
6020 // Add beta*bias
6021#if defined(BETA)
6022 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
6023
6024#if defined(BROADCAST_BIAS)
6025 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
6026
6027 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6028
6029#ifndef UNIT_BETA
6030 SCALE_BLOCK(1, half, bias, BETA);
6031#endif // UNIT_BIAS
6032
6033 // acc = acc + bias[broadcasted]
6034 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
6035
6036#else // defined(BROADCAST_BIAS)
6037 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
6038 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
6039
6040 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6041
6042#ifndef UNIT_BETA
6043 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
6044#endif // UNIT_BIAS
6045
6046 // acc = acc + bias
6047 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
6048
6049#endif // defined(BROADCAST_BIAS)
6050#endif // defined(BETA)
6051
6052#if defined(ACTIVATION_TYPE)
6053 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
6054#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006055
6056 // Store the output block
Usama Arif0681e3b2019-04-25 14:28:07 +01006057 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006058}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006059#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006060
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006061#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006062
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006063#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006064/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
6065 *
Gian Marco19835e52018-01-30 13:35:54 +00006066 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006067 *
6068 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
6069 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
6070 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6071 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
6072 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006073 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
6074 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006075 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006076 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006077 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6078 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6079 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6080 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006081 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
6082 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006083 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6084 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006085__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
6086 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006087{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006088 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006089 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
6090 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006091
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006092 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006093 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
6094
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006095 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006096 float4 c = vload4(0, (__global float *)src.ptr);
6097
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006098 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006099 float4 out = alpha_ab + (float4)BETA * c;
6100
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006101 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006102 vstore4(out, 0, (__global float *)dst.ptr);
6103}
6104
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * Each work-item updates 8 consecutive F16 elements: dst = dst + BETA * src
 *
 * @note The beta's value need to be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the destination already holds the matrix product)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006147#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006148
#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Row idy of A is multiplied by the z-slice idy of B, and each work-item produces 4 consecutive F32 output values.
 *
 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First output column handled by this work-item (4 floats per work-item) and the row of A being processed
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    // s0: byte offset of row idy of A, s1: byte offset of z-slice idy of B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    // Byte offset one past the last element of the current row of A
    int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume 2 elements of A and 2 rows of B per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Left-over loop: consume the last element of A when WIDTH_VECTOR_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
6216
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item adds VECTOR_SIZE bias elements to VECTOR_SIZE elements of the accumulate tensor, in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE elements from the accumulate tensor and from the biases vector
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    // Accumulate the biases
    accum_value = biases_value + accum_value;

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)