blob: 8e628e8d01523019938bff065e134e9620038f14 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
28#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
29#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
30#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
31#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
32#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
33#define CONCAT_INC(K0) INC##K0
34#define INC(K0) CONCAT_INC(K0)
35
36#if(SRC_WIDTH % K0)
37#define BOUNDARY_CONDITION_X(x, a) \
38 ({ \
39 a = select(0, a, CONVERT(((x * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
40 })
41#else // (SRC_WIDTH % K0)
42#define BOUNDARY_CONDITION_X(x, a) \
43 ({})
44#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
46/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
47 * the output matrix unrolling the values.
48 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +010049 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
50 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
51 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
52 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000053 * @note Only the following values for M0, K0 and V0 are supported:
54 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000055 * K0: 2,3,4,8,16
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000056 * V0: greater than 0
Gian Marco Iodiced1f54762019-07-19 09:54:47 +010057 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000058 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
59 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
60 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
61 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
62 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
63 *
64 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
65 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
66 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
67 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
68 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
69 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
70 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
71 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
72 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
73 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
74 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
75 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
76 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
77 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
78 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
79 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
80 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
81 */
82__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
83 TENSOR3D_DECLARATION(dst)
84#if defined(REINTERPRET_INPUT_AS_3D)
85 ,
86 uint cross_plane_pad
87#endif // REINTERPRET_INPUT_AS_3D
88 )
89{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000090 // Block size
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000091#define BLOCK_SIZE ((M0) * (K0))
92
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000093 // Output offset X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000094#if defined(INTERLEAVE)
95#define OUTPUT_OFFSET_X (K0)
96#else // defined(INTERLEAVE)
97#define OUTPUT_OFFSET_X (BLOCK_SIZE)
98#endif // defined(INTERLEAVE)
99
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000100 // Output step X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000101#if defined(INTERLEAVE)
102#define OUTPUT_STEP_X (K0) * (V0)
103#else // Do not interleave
104#define OUTPUT_STEP_X (K0)
105#endif // defined(INTERLEAVE)
106
107 // Compute source and destination addresses
108 uint x = get_global_id(0);
109 uint y = get_global_id(1);
110 uint z = get_global_id(2);
111
112 // ------------------ Compute input/output addresses ---------------------------
113
114 // Compute the input address
115 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
116
117 // Compute the output address
118 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
119 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
120
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000121 // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
122 REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000123
124#if defined(REINTERPRET_INPUT_AS_3D)
125 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
126 // multiply src_stride_z by DEPTH_GEMM3D
127
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000128 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
129
130 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +0100131 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000132
133#else // defined(REINTERPRET_INPUT_AS_3D)
134
135 input_ptr += z * (uint)src_stride_z;
136
137#endif // defined(REINTERPRET_INPUT_AS_3D)
138
139 // Add offset for batched GEMM
140 output_ptr += z * (uint)dst_stride_z;
141
142 // ---------------------------Load input values --------------------------------
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000143 // Load values from the LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +0100144 LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000145 BOUNDARY_CONDITION_X(x, a0);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000146#if M0 > 1
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000147 BOUNDARY_CONDITION_X(x, a1);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000148#endif // M0 > 1
149#if M0 > 2
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000150 BOUNDARY_CONDITION_X(x, a2);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000151#endif // M0 > 2
152#if M0 > 3
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000153 BOUNDARY_CONDITION_X(x, a3);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000154#endif // M0 > 3
155#if M0 > 4
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000156 BOUNDARY_CONDITION_X(x, a4);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000157#endif // M0 > 4
158#if M0 > 5
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000159 BOUNDARY_CONDITION_X(x, a5);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000160#endif // M0 > 5
161#if M0 > 6
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000162 BOUNDARY_CONDITION_X(x, a6);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000163#endif // M0 > 6
164#if M0 > 7
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000165 BOUNDARY_CONDITION_X(x, a7);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000166#endif // M0 > 7
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000167 // ---------------------------Store output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100168 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
169 STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000170
171#undef BLOCK_SIZE
172#undef OUTPUT_OFFSET_X
173#undef OUTPUT_STEP_X
174}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
176#if M0 == 2
177#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
178 ({ \
179 VEC_DATA_TYPE(DATA_TYPE, M0) \
180 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
181 VSTORE(M0) \
182 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
183 })
184#elif M0 == 3 // M0 == 3
185#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
186 ({ \
187 VEC_DATA_TYPE(DATA_TYPE, M0) \
188 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
189 VSTORE(M0) \
190 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
191 })
192#elif M0 == 4 // M0 == 4
193#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
194 ({ \
195 VEC_DATA_TYPE(DATA_TYPE, M0) \
196 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
197 VSTORE(M0) \
198 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
199 })
200#elif M0 == 5 // M0 == 5
201#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
202 ({ \
203 VEC_DATA_TYPE(DATA_TYPE, 4) \
204 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
205 DATA_TYPE res1 = a4.s##i; \
206 VSTORE(4) \
207 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
208 *((__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4) = res1; \
209 })
210#elif M0 == 6 // M0 == 6
211#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
212 ({ \
213 VEC_DATA_TYPE(DATA_TYPE, 4) \
214 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
215 VEC_DATA_TYPE(DATA_TYPE, 2) \
216 res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
217 VSTORE(4) \
218 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
219 VSTORE(2) \
220 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
221 })
222#elif M0 == 7 // M0 == 7
223#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
224 ({ \
225 VEC_DATA_TYPE(DATA_TYPE, 4) \
226 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
227 VEC_DATA_TYPE(DATA_TYPE, 3) \
228 res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
229 VSTORE(4) \
230 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
231 VSTORE(3) \
232 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
233 })
234#elif M0 == 8 // M0 == 8
235#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
236 ({ \
237 VEC_DATA_TYPE(DATA_TYPE, M0) \
238 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
239 VSTORE(M0) \
240 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
241 })
242#else // M0 not supported
243#error "M0 value not supported"
244#endif // N0 conditions
245
246/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
247 * the output matrix unrolling the values.
248 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +0100249 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
250 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
251 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
252 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000253 * @note Only the following values for M0, K0 and V0 are supported:
254 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000255 * K0: 2,3,4,8,16
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000256 * V0: greater than 0
Gian Marco Iodiced1f54762019-07-19 09:54:47 +0100257 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000258 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
259 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
260 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
261 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
262 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
263 *
264 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
265 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
266 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
267 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
268 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
269 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
270 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
271 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
272 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
273 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
274 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
275 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
276 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
277 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
278 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
280 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
281 */
282__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
283 TENSOR3D_DECLARATION(dst)
284#if defined(REINTERPRET_INPUT_AS_3D)
285 ,
286 uint cross_plane_pad
287#endif // REINTERPRET_INPUT_AS_3D
288 )
289{
290 // Block size
291#define BLOCK_SIZE ((M0) * (K0))
292
293 // Output offset X
294#if defined(INTERLEAVE)
295#define OUTPUT_OFFSET_X (M0)
296#else // defined(INTERLEAVE)
297#define OUTPUT_OFFSET_X (BLOCK_SIZE)
298#endif // defined(INTERLEAVE)
299
300 // Output step X
301#if defined(INTERLEAVE)
302#define OUTPUT_STEP_X (M0) * (V0)
303#else // Do not interleave
304#define OUTPUT_STEP_X (M0)
305#endif // defined(INTERLEAVE)
306
307 // Compute source and destination addresses
308 uint x = get_global_id(0);
309 uint y = get_global_id(1);
310 uint z = get_global_id(2);
311
312 // ------------------ Compute input/output addresses ---------------------------
313
314 // Compute the input address
315 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
316
317 // Compute the output address
318 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
319 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
320
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000321 // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
322 REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000323
324#if defined(REINTERPRET_INPUT_AS_3D)
325 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
326 // multiply src_stride_z by DEPTH_GEMM3D
327
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000328 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
329
330 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +0100331 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000332
333#else // defined(REINTERPRET_INPUT_AS_3D)
334
335 input_ptr += z * (uint)src_stride_z;
336
337#endif // defined(REINTERPRET_INPUT_AS_3D)
338
339 // Add offset for batched GEMM
340 output_ptr += z * (uint)dst_stride_z;
341
342 // ---------------------------Load input values --------------------------------
343
344 // Load values from the LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +0100345 LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000346 BOUNDARY_CONDITION_X(x, a0);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000347#if M0 > 1
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000348 BOUNDARY_CONDITION_X(x, a1);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000349#endif // M0 > 1
350#if M0 > 2
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000351 BOUNDARY_CONDITION_X(x, a2);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000352#endif // M0 > 2
353#if M0 > 3
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000354 BOUNDARY_CONDITION_X(x, a3);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000355#endif // M0 > 3
356#if M0 > 4
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000357 BOUNDARY_CONDITION_X(x, a4);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000358#endif // M0 > 4
359#if M0 > 5
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000360 BOUNDARY_CONDITION_X(x, a5);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000361#endif // M0 > 5
362#if M0 > 6
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000363 BOUNDARY_CONDITION_X(x, a6);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000364#endif // M0 > 6
365#if M0 > 7
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000366 BOUNDARY_CONDITION_X(x, a7);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000367#endif // M0 > 7
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000368 // ---------------------------Transpose and store block -----------------------
369
370 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
371 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
372#if K0 > 2
373 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000374#endif // K0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000375#if K0 > 3
376 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
377#endif // K0 > 3
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000378#if K0 > 4
379 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
380 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
381 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
382 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
383#endif // K0 > 4
384#if K0 > 8
385 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
386 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
387 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
388 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
389 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
390 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
391 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
392 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
393#endif // K0 > 8
394
395#undef BLOCK_SIZE
396#undef OUTPUT_OFFSET_X
397#undef OUTPUT_STEP_X
398}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
402/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
403 * the output matrix unrolling the values.
404 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +0100405 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
406 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
407 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
408 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000409 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
410 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000411 * N0: 2,3,4,8,16
412 * K0: 1,2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000413 * H0: greater than 0
414 *
415 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
416 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
417 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
418 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
419 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
420 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
421 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
422 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
423 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
424 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
425 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
426 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
427 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
428 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
429 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
430 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
431 */
432__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
433 TENSOR3D_DECLARATION(dst))
434{
435 // Block size
436#define BLOCK_SIZE ((K0) * (N0))
437
438 // Output offset X
439#if defined(INTERLEAVE)
440#define OUTPUT_OFFSET_X (N0)
441#else // defined(INTERLEAVE)
442#define OUTPUT_OFFSET_X (BLOCK_SIZE)
443#endif // defined(INTERLEAVE)
444
445 // Output step X
446#if defined(INTERLEAVE)
447#define OUTPUT_STEP_X (N0) * (H0)
448#else // Do not interleave
449#define OUTPUT_STEP_X (N0)
450#endif // defined(INTERLEAVE)
451
452 // Compute source and destination addresses
453 uint x = get_global_id(0);
454 uint y = get_global_id(1);
455 uint z = get_global_id(2);
456
457 // ------------------ Compute input/output addresses ---------------------------
458
459 // Compute the input address
460 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
461
462 // Compute the output address
463 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
464 x / (uint)H0)
465 * (uint)dst_stride_y)
466 + z * (uint)dst_stride_z;
467
468 // ---------------------------Load input values --------------------------------
469
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000470 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); ////uint a0=0, a1=0, a2=0...a(M0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000471
472 // Load values from the RHS matrix
473 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
474#if K0 > 1
475 if(y * (uint)K0 + 1 < SRC_HEIGHT)
476 {
477 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
478 }
479#endif // K0 > 1
480#if K0 > 2
481 if(y * (uint)K0 + 2 < SRC_HEIGHT)
482 {
483 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
484 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000485#endif // K0 > 2
486#if K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000487 if(y * (uint)K0 + 3 < SRC_HEIGHT)
488 {
489 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
490 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000491#endif // K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000492#if K0 > 4
493 if(y * (uint)K0 + 4 < SRC_HEIGHT)
494 {
495 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
496 }
497 if(y * (uint)K0 + 5 < SRC_HEIGHT)
498 {
499 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
500 }
501 if(y * (uint)K0 + 6 < SRC_HEIGHT)
502 {
503 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
504 }
505 if(y * (uint)K0 + 7 < SRC_HEIGHT)
506 {
507 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
508 }
509#endif // K0 > 4
510#if K0 > 8
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000511 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000512 {
513 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
514 }
515 if(y * (uint)K0 + 9 < SRC_HEIGHT)
516 {
517 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
518 }
519 if(y * (uint)K0 + 10 < SRC_HEIGHT)
520 {
521 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
522 }
523 if(y * (uint)K0 + 11 < SRC_HEIGHT)
524 {
525 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
526 }
527 if(y * (uint)K0 + 12 < SRC_HEIGHT)
528 {
529 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
530 }
531 if(y * (uint)K0 + 13 < SRC_HEIGHT)
532 {
533 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
534 }
535 if(y * (uint)K0 + 14 < SRC_HEIGHT)
536 {
537 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
538 }
539 if(y * (uint)K0 + 15 < SRC_HEIGHT)
540 {
541 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
542 }
543#endif // K0 > 8
544
545 // ---------------------------Store output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100546 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
547 STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000548
549#undef BLOCK_SIZE
550#undef OUTPUT_OFFSET_X
551#undef OUTPUT_STEP_X
552}
553
#if defined(TRANSPOSE)

// Reject unsupported block sizes up front with an explicit diagnostic instead of a
// cascade of unrelated compile errors further down.
#if !((K0 == 2) || (K0 == 3) || (K0 == 4) || (K0 == 8) || (K0 == 16))
#error "Not supported K0 value"
#endif // K0 validation
#if !((N0 == 2) || (N0 == 3) || (N0 == 4) || (N0 == 8) || (N0 == 16))
#error "Not supported N0 value"
#endif // N0 validation

/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                 N0: 2,3,4,8,16
 *                                 K0: 2,3,4,8,16
 *                                 H0: greater than 0
 *       Unsupported K0/N0 values are rejected at compile time with an explicit #error.
 * @note Block rows whose source row index (y * K0 + i) is not smaller than SRC_HEIGHT are not read:
 *       the corresponding lanes of the transposed block are stored as zero.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X (fully parenthesised so it composes safely inside larger expressions)
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: top-left element of the K0xN0 block handled by this work-item
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive blocks along X share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    // a0, a1, ... a(K0-1): one N0-wide row of the block each, zero-initialised
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Load block row <row> into a<id> only when it lies inside the source matrix; otherwise a<id> stays zero.
    // <id> is the hexadecimal variable suffix (0..F), <row> the same index written in decimal.
#define RHS_T_LOAD_ROW(id, row)                                                         \
    if(y * (uint)K0 + (row) < SRC_HEIGHT)                                               \
    {                                                                                   \
        a##id = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + (row) * src_stride_y)); \
    }

    // Row 0 is always loaded
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    RHS_T_LOAD_ROW(1, 1);
#if K0 > 2
    RHS_T_LOAD_ROW(2, 2);
#endif // K0 > 2
#if K0 > 3
    RHS_T_LOAD_ROW(3, 3);
#endif // K0 > 3
#if K0 > 4
    RHS_T_LOAD_ROW(4, 4);
    RHS_T_LOAD_ROW(5, 5);
    RHS_T_LOAD_ROW(6, 6);
    RHS_T_LOAD_ROW(7, 7);
#endif // K0 > 4
#if K0 > 8
    RHS_T_LOAD_ROW(8, 8);
    RHS_T_LOAD_ROW(9, 9);
    RHS_T_LOAD_ROW(A, 10);
    RHS_T_LOAD_ROW(B, 11);
    RHS_T_LOAD_ROW(C, 12);
    RHS_T_LOAD_ROW(D, 13);
    RHS_T_LOAD_ROW(E, 14);
    RHS_T_LOAD_ROW(F, 15);
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    // res0, res1, ... res(N0-1): one K0-wide column of the block each
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

    // Gather lane <i> (hexadecimal digit 0..F) of every loaded row into one K0-wide vector,
    // i.e. column <i> of the block. This computes the K0xN0 -> N0xK0 transposition.
#if K0 == 2
#define RHS_T_COLUMN(i) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##i, a1.s##i)
#elif K0 == 3 // K0 == 3
#define RHS_T_COLUMN(i) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##i, a1.s##i, a2.s##i)
#elif K0 == 4 // K0 == 4
#define RHS_T_COLUMN(i) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i)
#elif K0 == 8 // K0 == 8
#define RHS_T_COLUMN(i) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i)
#elif K0 == 16 // K0 == 16
#define RHS_T_COLUMN(i) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i, \
                                                        a8.s##i, a9.s##i, aA.s##i, aB.s##i, aC.s##i, aD.s##i, aE.s##i, aF.s##i)
#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    res0 = RHS_T_COLUMN(0);
    res1 = RHS_T_COLUMN(1);
#if N0 > 2
    res2 = RHS_T_COLUMN(2);
#endif // N0 > 2
#if N0 > 3
    res3 = RHS_T_COLUMN(3);
#endif // N0 > 3
#if N0 > 4
    res4 = RHS_T_COLUMN(4);
    res5 = RHS_T_COLUMN(5);
    res6 = RHS_T_COLUMN(6);
    res7 = RHS_T_COLUMN(7);
#endif // N0 > 4
#if N0 > 8
    res8 = RHS_T_COLUMN(8);
    res9 = RHS_T_COLUMN(9);
    resA = RHS_T_COLUMN(A);
    resB = RHS_T_COLUMN(B);
    resC = RHS_T_COLUMN(C);
    resD = RHS_T_COLUMN(D);
    resE = RHS_T_COLUMN(E);
    resF = RHS_T_COLUMN(F);
#endif // N0 > 8

    // ---------------------------Store the output values ------------------------------
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef RHS_T_LOAD_ROW
#undef RHS_T_COLUMN
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
886#define CONCAT(a, b) a##b
887
888#define ARM_DOT1(a, b, c) \
889 ({ \
890 c = fma(a, b, c); \
891 })
892#define ARM_DOT2(a, b, c) \
893 ({ \
894 c = fma(a.s0, b.s0, c); \
895 c = fma(a.s1, b.s1, c); \
896 })
897#define ARM_DOT3(a, b, c) \
898 ({ \
899 ARM_DOT2(a, b, c); \
900 c = fma((a.s2), (b.s2), c); \
901 })
902#define ARM_DOT4(a, b, c) \
903 ({ \
904 ARM_DOT3(a, b, c); \
905 c = fma((a.s3), (b.s3), c); \
906 })
907#define ARM_DOT8(a, b, c) \
908 ({ \
909 ARM_DOT4((a.lo), (b.lo), c); \
910 ARM_DOT4((a.hi), (b.hi), c); \
911 })
912#define ARM_DOT16(a, b, c) \
913 ({ \
914 ARM_DOT8((a.lo), (b.lo), c); \
915 ARM_DOT8((a.hi), (b.hi), c); \
916 })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001013 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001018 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001025 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001034 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1036 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1037 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1038 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1042 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1044 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1058 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
1115 // Compute RHS matrix address
1116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
1161 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
1198 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
/* VFMA(a, b, c): fused multiply-accumulate into an lvalue, i.e. c = fma(a, b, c).
 * a and b may be scalars or vectors accepted by fma(); c must be assignable. */
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
1300
1301#if M0 == 1
1302#define LD_RHS_VFMA_M0xN0(i, a, c) \
1303 ({ \
1304 VEC_DATA_TYPE(DATA_TYPE, N0) \
1305 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1306 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1307 })
1308#elif M0 == 2 // M0 == 2
1309#define LD_RHS_VFMA_M0xN0(i, a, c) \
1310 ({ \
1311 VEC_DATA_TYPE(DATA_TYPE, N0) \
1312 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1313 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1314 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1315 })
1316#elif M0 == 3 // M0 == 3
1317#define LD_RHS_VFMA_M0xN0(i, a, c) \
1318 ({ \
1319 VEC_DATA_TYPE(DATA_TYPE, N0) \
1320 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1321 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1322 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1323 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1324 })
1325#elif M0 == 4 // M0 == 4
1326#define LD_RHS_VFMA_M0xN0(i, a, c) \
1327 ({ \
1328 VEC_DATA_TYPE(DATA_TYPE, N0) \
1329 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1330 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1331 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1332 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1333 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1334 })
1335#elif M0 == 5 // M0 == 5
1336#define LD_RHS_VFMA_M0xN0(i, a, c) \
1337 ({ \
1338 VEC_DATA_TYPE(DATA_TYPE, N0) \
1339 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1340 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1341 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1342 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1343 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1344 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1345 })
1346#elif M0 == 6 // M0 == 6
1347#define LD_RHS_VFMA_M0xN0(i, a, c) \
1348 ({ \
1349 VEC_DATA_TYPE(DATA_TYPE, N0) \
1350 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1351 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1352 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1353 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1354 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1355 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1356 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1357 })
1358#elif M0 == 7 // M0 == 7
1359#define LD_RHS_VFMA_M0xN0(i, a, c) \
1360 ({ \
1361 VEC_DATA_TYPE(DATA_TYPE, N0) \
1362 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1363 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1364 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1365 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1366 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1367 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1368 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1369 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1370 })
1371#elif M0 == 8 // M0 == 8
1372#define LD_RHS_VFMA_M0xN0(i, a, c) \
1373 ({ \
1374 VEC_DATA_TYPE(DATA_TYPE, N0) \
1375 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1376 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1377 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1378 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1379 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1380 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1381 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1382 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1383 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1384 })
1385#else // M0 not supported
1386#error "M0 not supported"
1387#endif // M0 not supported
1388
1389/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1390 * The LHS matrix is NOT reshaped
1391 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1392 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001393 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001394 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90).
1395 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1396 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1397 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001398 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1399 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1400 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1401 * - N0 = 2, 3, 4, 8, 16
1402 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001403 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001404 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001405 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001406 * The activation function is performed after the bias addition
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001407 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1408 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1409 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1410 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1411 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1412 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1413 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001414 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1415 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1416 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1417 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1418 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1419 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1420 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1421 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1422 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1423 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1424 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1425 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001426 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1427 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001428 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001429 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001430 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1431 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1432 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1433 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1434 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1435 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1436 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1437 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1438 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1439 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001440 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001441 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1442 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1443 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001444 */
1445__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1446 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001447#if defined(BETA)
1448 IMAGE_DECLARATION(bias),
1449#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001450 IMAGE_DECLARATION(dst),
1451 uint lhs_stride_z,
1452 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001453#if defined(BETA)
1454 uint bias_stride_z,
1455#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001456 uint dst_stride_z
1457#if defined(REINTERPRET_INPUT_AS_3D)
1458 ,
1459 uint lhs_cross_plane_pad
1460#endif // REINTERPRET_INPUT_AS_3D
1461#if defined(REINTERPRET_OUTPUT_AS_3D)
1462 ,
1463 uint dst_cross_plane_pad
1464#endif // REINTERPRET_OUTPUT_AS_3D
1465 )
1466{
1467 // Block size
1468#define RHS_BLOCK_SIZE ((K0) * (N0))
1469
1470 // RHS offset and step X
1471#if defined(RHS_INTERLEAVE)
1472#define RHS_OFFSET_X (N0)
1473#define RHS_STEP_X ((N0) * (H0))
1474#define RHS_STEP_LOOP (1)
1475#else // defined(RHS_INTERLEAVE)
1476#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1477#define RHS_STEP_X (N0)
1478#define RHS_STEP_LOOP (H0)
1479#endif // defined(RHS_INTERLEAVE)
1480
1481 uint x = get_global_id(0);
1482 uint y = get_global_id(1);
1483 uint z = get_global_id(2);
1484
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001485#if defined(DUMMY_WORK_ITEMS)
1486 if((x * N0 >= N) || (y * M0 >= M))
1487 {
1488 return;
1489 }
1490#endif // defined(DUMMY_WORK_ITEMS)
1491
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001492 // Compute LHS matrix address
1493 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1494
1495 // Compute RHS matrix address
1496 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1497
1498#if defined(MATRIX_B_DEPTH)
1499 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1500 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1501#else // defined(MATRIX_B_DEPTH)
1502 rhs_offset += z * rhs_stride_z;
1503#endif // defined(MATRIX_B_DEPTH)
1504
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001505 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1506 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001507
1508#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001509
1510 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001511 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001512
1513 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1514 // multiply lhs_stride_z by DEPTH_GEMM3D
1515 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1516
1517#else // defined(REINTERPRET_INPUT_AS_3D)
1518
1519 // Add offset for batched GEMM
1520 lhs_offset += z * lhs_stride_z;
1521
1522#endif // defined(REINTERPRET_INPUT_AS_3D)
1523
1524 // Initialize the accumulators
1525 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1526
1527 int i = 0;
1528 for(; i <= (K - K0); i += K0)
1529 {
1530 // Supported cases (M0, K0):
1531 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1532 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1533 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1534 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1535 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1536 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1537 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1538 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1539 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001540 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001541
1542 LD_RHS_VFMA_M0xN0(0, a, c);
1543 LD_RHS_VFMA_M0xN0(1, a, c);
1544#if K0 > 2
1545 LD_RHS_VFMA_M0xN0(2, a, c);
1546#endif // K0 > 2
1547#if K0 > 3
1548 LD_RHS_VFMA_M0xN0(3, a, c);
1549#endif // K0 > 3
1550#if K0 > 4
1551 LD_RHS_VFMA_M0xN0(4, a, c);
1552 LD_RHS_VFMA_M0xN0(5, a, c);
1553 LD_RHS_VFMA_M0xN0(6, a, c);
1554 LD_RHS_VFMA_M0xN0(7, a, c);
1555#endif // K0 > 4
1556#if K0 > 8
1557 LD_RHS_VFMA_M0xN0(8, a, c);
1558 LD_RHS_VFMA_M0xN0(9, a, c);
1559 LD_RHS_VFMA_M0xN0(A, a, c);
1560 LD_RHS_VFMA_M0xN0(B, a, c);
1561 LD_RHS_VFMA_M0xN0(C, a, c);
1562 LD_RHS_VFMA_M0xN0(D, a, c);
1563 LD_RHS_VFMA_M0xN0(E, a, c);
1564 LD_RHS_VFMA_M0xN0(F, a, c);
1565#endif // K0 > 8
1566
1567 lhs_offset += K0 * sizeof(DATA_TYPE);
1568 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1569 }
1570
1571 // Left-over accumulations
1572 for(; i < K; ++i)
1573 {
1574 // Load values from LHS matrix
1575 VEC_DATA_TYPE(DATA_TYPE, 2)
1576 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1577#if M0 > 1
1578 VEC_DATA_TYPE(DATA_TYPE, 2)
1579 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1580#endif // M0 > 1
1581#if M0 > 2
1582 VEC_DATA_TYPE(DATA_TYPE, 2)
1583 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1584#endif // M0 > 2
1585#if M0 > 3
1586 VEC_DATA_TYPE(DATA_TYPE, 2)
1587 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1588#endif // M0 > 3
1589#if M0 > 4
1590 VEC_DATA_TYPE(DATA_TYPE, 2)
1591 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1592#endif // M0 > 4
1593#if M0 > 5
1594 VEC_DATA_TYPE(DATA_TYPE, 2)
1595 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1596#endif // M0 > 5
1597#if M0 > 6
1598 VEC_DATA_TYPE(DATA_TYPE, 2)
1599 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1600#endif // M0 > 6
1601#if M0 > 7
1602 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001603 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001604#endif // M0 > 7
1605
1606 LD_RHS_VFMA_M0xN0(0, a, c);
1607
1608 lhs_offset += sizeof(DATA_TYPE);
1609 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1610 }
1611
1612 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1613
1614 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1615
1616#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001617 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001618 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001619
1620 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1621 // multiply dst_stride_z by DEPTH_GEMM3D
1622 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1623
1624#else // defined(REINTERPRET_OUTPUT_AS_3D)
1625
1626 // Add offset for batched GEMM
1627 dst_addr += z * dst_stride_z;
1628
1629#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1630
1631 // Multiply by the weight of matrix-matrix product and store the result
1632#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001633 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001634#endif // defined(ALPHA)
1635
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001636 // Add beta*bias
1637#if defined(BETA)
1638#if defined(BROADCAST_BIAS)
1639 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1640
1641 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1642
1643#ifndef UNIT_BETA
1644 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1645#endif // UNIT_BIAS
1646
1647 // c = c + bias[broadcasted]
1648 ADD_BLOCK_BROADCAST(M0, c, bias0);
1649
1650#else // defined(BROADCAST_BIAS)
1651 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1652 2) * bias_stride_z;
1653
1654 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1655
1656#ifndef UNIT_BETA
1657 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1658#endif // UNIT_BIAS
1659
1660 // c = c + bias
1661 ADD_BLOCK(M0, c, bias);
1662
1663#endif // defined(BROADCAST_BIAS)
1664#endif // defined(BETA)
1665
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001666#if defined(ACTIVATION_TYPE)
1667 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1668#endif // defined(ACTIVATION_TYPE)
1669
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001670 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001671 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001672
1673#undef RHS_BLOCK_SIZE
1674#undef RHS_OFFSET_X
1675#undef RHS_STEP_X
1676}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001677#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001678
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001679#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001680
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001681#if K0 == 2
1682#define ARM_DOT_K0(a, b, c) \
1683 ({ \
1684 c = fma(a.s0, b.s0, c); \
1685 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001686 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001687#elif K0 == 3 // K0 == 3
1688#define ARM_DOT_K0(a, b, c) \
1689 ({ \
1690 c = fma(a.s0, b.s0, c); \
1691 c = fma(a.s1, b.s1, c); \
1692 c = fma(a.s2, b.s2, c); \
1693 })
1694#elif K0 == 4 // K0 == 4
1695#define ARM_DOT_K0(a, b, c) \
1696 ({ \
1697 c = fma(a.s0, b.s0, c); \
1698 c = fma(a.s1, b.s1, c); \
1699 c = fma(a.s2, b.s2, c); \
1700 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001701 })
1702#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001703#define ARM_DOT_K0(a, b, c) \
1704 ({ \
1705 c = fma(a.s0, b.s0, c); \
1706 c = fma(a.s1, b.s1, c); \
1707 c = fma(a.s2, b.s2, c); \
1708 c = fma(a.s3, b.s3, c); \
1709 c = fma(a.s4, b.s4, c); \
1710 c = fma(a.s5, b.s5, c); \
1711 c = fma(a.s6, b.s6, c); \
1712 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001713 })
1714#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001715#define ARM_DOT_K0(a, b, c) \
1716 ({ \
1717 c = fma(a.s0, b.s0, c); \
1718 c = fma(a.s1, b.s1, c); \
1719 c = fma(a.s2, b.s2, c); \
1720 c = fma(a.s3, b.s3, c); \
1721 c = fma(a.s4, b.s4, c); \
1722 c = fma(a.s5, b.s5, c); \
1723 c = fma(a.s6, b.s6, c); \
1724 c = fma(a.s7, b.s7, c); \
1725 c = fma(a.s8, b.s8, c); \
1726 c = fma(a.s9, b.s9, c); \
1727 c = fma(a.sA, b.sA, c); \
1728 c = fma(a.sB, b.sB, c); \
1729 c = fma(a.sC, b.sC, c); \
1730 c = fma(a.sD, b.sD, c); \
1731 c = fma(a.sE, b.sE, c); \
1732 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001733 })
1734#else // K0 not supported
1735#error "K0 value not supported"
1736#endif // K0 conditions
1737
1738#if N0 == 2
1739#define ARM_DOT_K0XN0(a, b, c) \
1740 ({ \
1741 ARM_DOT_K0((a), (b##0), (c.s0)); \
1742 ARM_DOT_K0((a), (b##1), (c.s1)); \
1743 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001744#elif N0 == 3 // N0 == 3
1745#define ARM_DOT_K0XN0(a, b, c) \
1746 ({ \
1747 ARM_DOT_K0((a), (b##0), (c.s0)); \
1748 ARM_DOT_K0((a), (b##1), (c.s1)); \
1749 ARM_DOT_K0((a), (b##2), (c.s2)); \
1750 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001751#elif N0 == 4 // N0 == 4
1752#define ARM_DOT_K0XN0(a, b, c) \
1753 ({ \
1754 ARM_DOT_K0((a), (b##0), (c.s0)); \
1755 ARM_DOT_K0((a), (b##1), (c.s1)); \
1756 ARM_DOT_K0((a), (b##2), (c.s2)); \
1757 ARM_DOT_K0((a), (b##3), (c.s3)); \
1758 })
1759#elif N0 == 8 // N0 == 8
1760#define ARM_DOT_K0XN0(a, b, c) \
1761 ({ \
1762 ARM_DOT_K0((a), (b##0), (c.s0)); \
1763 ARM_DOT_K0((a), (b##1), (c.s1)); \
1764 ARM_DOT_K0((a), (b##2), (c.s2)); \
1765 ARM_DOT_K0((a), (b##3), (c.s3)); \
1766 ARM_DOT_K0((a), (b##4), (c.s4)); \
1767 ARM_DOT_K0((a), (b##5), (c.s5)); \
1768 ARM_DOT_K0((a), (b##6), (c.s6)); \
1769 ARM_DOT_K0((a), (b##7), (c.s7)); \
1770 })
1771#elif N0 == 16 // N0 == 16
1772#define ARM_DOT_K0XN0(a, b, c) \
1773 ({ \
1774 ARM_DOT_K0((a), (b##0), (c.s0)); \
1775 ARM_DOT_K0((a), (b##1), (c.s1)); \
1776 ARM_DOT_K0((a), (b##2), (c.s2)); \
1777 ARM_DOT_K0((a), (b##3), (c.s3)); \
1778 ARM_DOT_K0((a), (b##4), (c.s4)); \
1779 ARM_DOT_K0((a), (b##5), (c.s5)); \
1780 ARM_DOT_K0((a), (b##6), (c.s6)); \
1781 ARM_DOT_K0((a), (b##7), (c.s7)); \
1782 ARM_DOT_K0((a), (b##8), (c.s8)); \
1783 ARM_DOT_K0((a), (b##9), (c.s9)); \
1784 ARM_DOT_K0((a), (b##A), (c.sA)); \
1785 ARM_DOT_K0((a), (b##B), (c.sB)); \
1786 ARM_DOT_K0((a), (b##C), (c.sC)); \
1787 ARM_DOT_K0((a), (b##D), (c.sD)); \
1788 ARM_DOT_K0((a), (b##E), (c.sE)); \
1789 ARM_DOT_K0((a), (b##F), (c.sF)); \
1790 })
1791#else // N0 not supported
1792#error "N0 value not supported"
1793#endif // N0 conditions
1794
1795/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1796 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1797 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1798 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001799 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001800 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
1801 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1802 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1803 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1806 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01001807 * - M0 = 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001808 * - N0 = 2, 3, 4, 8, 16
1809 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001810 * - V0 >= 1
1811 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001812 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001813 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001814 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001815 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001816 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1817 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1818 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1819 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1820 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001821 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1822 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1823 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1824 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1825 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1826 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1827 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1828 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1829 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1830 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1831 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1832 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1833 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1834 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1835 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1836 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1837 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1838 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1839 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1840 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1841 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1842 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1843 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1844 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1845 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1846 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1847 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1848 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1849 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1850 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001851 */
__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
                                            IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                            IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                            IMAGE_DECLARATION(dst),
                                            uint k,
                                            uint lhs_stride_z,
                                            uint rhs_stride_z,
#if defined(BETA)
                                            uint bias_stride_z,
#endif //defined(BETA)
                                            uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                            ,
                                            uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                           )
{
    // Number of elements in one M0xK0 LHS block
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS layout helpers (in elements):
    // - LHS_OFFSET_X:  distance between the starts of two consecutive blocks on one reshaped row
    // - LHS_STEP_X:    distance between two consecutive rows of the same block
    // - LHS_STEP_LOOP: extra multiplier applied when advancing to the next K0 slice
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (K0)
#define LHS_STEP_X ((K0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (K0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // Number of elements in one K0xN0 RHS block
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X: same meaning as the LHS helpers above, with N0/H0 in place of M0/V0
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose output block starts outside the M x N result have nothing to compute
    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address: block (y % V0) of reshaped row (y / V0), batch z
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
                               (get_global_id(2) * lhs_stride_z);

    // Compute RHS matrix address: block (x % H0) of reshaped row (x / H0)
    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_addr += get_global_id(2) * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize the accumulators: c0 .. c(M0-1), one N0-wide vector per output row
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);

    // Per-row offsets passed to LOAD_BLOCK: M0 zeros for the LHS rows, 16 zeros (enough for any N0) for the RHS/bias rows
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

    // k is unsigned, so use an unsigned counter to avoid a signed/unsigned comparison
    for(uint i = 0; i < k; i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix: M0 rows of K0 elements (a0 .. a(M0-1))
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);

        // Load values from RHS matrix: N0 rows of K0 elements (b0 .. b(N0-1))
        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate output row m from LHS row a<m> and the N0 RHS rows
        ARM_DOT_K0XN0(a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(a7, b, c7);
#endif // M0 > 7

        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
        rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Drop every kernel-local helper macro (including the *_STEP_LOOP ones) so later kernels start clean
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef LHS_STEP_LOOP
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
giuros01b3204e72019-04-01 13:50:22 +01002038
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002039#if defined(LHS_TRANSPOSE)
2040
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)

#if GPU_ARCH == GPU_ARCH_MIDGARD
// Midgard: plain vector multiply-add, c += a * b.
// Wrapped in a statement expression (like the fma-based variants below) so that "ARM_VFMA(...);"
// is exactly one statement and can be used as the body of an if/else without braces.
#define ARM_VFMA(SIZE, a, b, c) ({ (c) += (a) * (b); })
#else // GPU_ARCH == GPU_ARCH_MIDGARD
// ARM_VFMA_<N>(a, b, c): component-wise fused multiply-add on N-wide vectors, c.sX = fma(a.sX, b.sX, c.sX)
#define ARM_VFMA_1(a, b, c)           \
    ({                                \
        (c) = fma((a), (b), (c));     \
    })
#define ARM_VFMA_2(a, b, c)                   \
    ({                                        \
        (c).s0 = fma((a).s0, (b).s0, (c).s0); \
        (c).s1 = fma((a).s1, (b).s1, (c).s1); \
    })
#define ARM_VFMA_3(a, b, c)                   \
    ({                                        \
        ARM_VFMA_2(a, b, c);                  \
        (c).s2 = fma((a).s2, (b).s2, (c).s2); \
    })
#define ARM_VFMA_4(a, b, c)                   \
    ({                                        \
        ARM_VFMA_3(a, b, c);                  \
        (c).s3 = fma((a).s3, (b).s3, (c).s3); \
    })
#define ARM_VFMA_8(a, b, c)                   \
    ({                                        \
        ARM_VFMA_4(a, b, c);                  \
        (c).s4 = fma((a).s4, (b).s4, (c).s4); \
        (c).s5 = fma((a).s5, (b).s5, (c).s5); \
        (c).s6 = fma((a).s6, (b).s6, (c).s6); \
        (c).s7 = fma((a).s7, (b).s7, (c).s7); \
    })
#define ARM_VFMA_16(a, b, c)                  \
    ({                                        \
        ARM_VFMA_8(a, b, c);                  \
        (c).s8 = fma((a).s8, (b).s8, (c).s8); \
        (c).s9 = fma((a).s9, (b).s9, (c).s9); \
        (c).sA = fma((a).sA, (b).sA, (c).sA); \
        (c).sB = fma((a).sB, (b).sB, (c).sB); \
        (c).sC = fma((a).sC, (b).sC, (c).sC); \
        (c).sD = fma((a).sD, (b).sD, (c).sD); \
        (c).sE = fma((a).sE, (b).sE, (c).sE); \
        (c).sF = fma((a).sF, (b).sF, (c).sF); \
    })

// Factory macro for the vector FMA: dispatches to ARM_VFMA_<SIZE>
#define ARM_VFMA(SIZE, a, b, c) ARM_VFMA_##SIZE((a), (b), (c))

#endif // GPU_ARCH == GPU_ARCH_MIDGARD
2090
// ARM_VVM_T_NT_<M0>xN0x1(N0, TYPE, a, b, C): one K step of the transposed-by-not-transposed product.
// a holds M0 LHS values, b is one N0-wide RHS row, and C##0 .. C##(M0-1) are the output rows:
// C##m += broadcast(a.s<m>) * b. The vector argument a is parenthesized before component access
// so that any expression can be passed.
#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)                \
    ({                                                        \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0));        \
    })
#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)                \
    ({                                                        \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s0), b, (C##0));   \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s1), b, (C##1));   \
    })
#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)                \
    ({                                                        \
        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);               \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s2), b, (C##2));   \
    })
#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)                \
    ({                                                        \
        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);               \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s3), b, (C##3));   \
    })
#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)                \
    ({                                                        \
        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);               \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s4), b, (C##4));   \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s5), b, (C##5));   \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s6), b, (C##6));   \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))((a).s7), b, (C##7));   \
    })

// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
// a is the column-vector (transposed)
// b is the row-vector (not transposed)
// C is the output matrix
// Lower case is a vector (a, b)
// Upper case is a matrix (C)
// M0 must be a literal (1, 2, 3, 4 or 8): it is pasted into the name of the macro to dispatch to
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)
2126
// ARM_MM_T_NT_M0xN0x<K>(M0, N0, TYPE, LHS, RHS, DST): accumulate K rank-1 updates into DST.
// LHS##k is the k-th row of the transposed LHS block (M0 values), RHS##k the k-th RHS row (N0 values),
// and DST##m the m-th output row. Rows are suffixed 0-9 then A-F.
// NOTE: the matrix parameters are deliberately NOT named A/B/C. The 16-step variant pastes the hex
// suffixes A, B and C, and a parameter with one of those names would be substituted into the paste
// (e.g. "A##A" would become "aa" instead of "aA").
#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, DST)                 \
    ({                                                                   \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##0), (RHS##0), DST);     \
    })
#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, DST)                 \
    ({                                                                   \
        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, DST);                \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##1), (RHS##1), DST);     \
    })
#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, DST)                 \
    ({                                                                   \
        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, DST);                \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##2), (RHS##2), DST);     \
    })
#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, DST)                 \
    ({                                                                   \
        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, DST);                \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##3), (RHS##3), DST);     \
    })
#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, DST)                 \
    ({                                                                   \
        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, DST);                \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##4), (RHS##4), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##5), (RHS##5), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##6), (RHS##6), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##7), (RHS##7), DST);     \
    })
// Rows 8..F are accumulated with the vector-by-vector macro directly. Routing them through
// ARM_MM_T_NT_M0xN0x1 would paste a ')' with '0', which is not a valid token.
#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, LHS, RHS, DST)                \
    ({                                                                   \
        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, DST);                \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##8), (RHS##8), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##9), (RHS##9), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##A), (RHS##A), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##B), (RHS##B), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##C), (RHS##C), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##D), (RHS##D), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##E), (RHS##E), DST);     \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##F), (RHS##F), DST);     \
    })

// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
// The dimensions for this matrix multiplications are defined through M0, N0 and K0
// The dimensions supported are:
// M0: 1, 2, 3, 4, 8
// N0: 1, 2, 3, 4, 8, 16
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// LHS, RHS and DST are the base names of the row variables of the three matrices
#define ARM_MM_T_NT(M0, N0, K0, TYPE, LHS, RHS, DST) CONCAT(ARM_MM_T_NT_M0xN0x, K0)(M0, N0, TYPE, LHS, RHS, DST)
2177
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
 *
 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 2, 3, 4, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - V0 >= 1
 *  - H0 >= 1
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  k                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped.
 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
                                            IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                            IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                            IMAGE_DECLARATION(dst),
                                            uint k,
                                            uint lhs_stride_z,
                                            uint rhs_stride_z,
#if defined(BETA)
                                            uint bias_stride_z,
#endif //defined(BETA)
                                            uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                            ,
                                            uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                           )
{
    // Number of elements in one (transposed) M0xK0 LHS block
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS layout helpers (in elements). The block is stored transposed, so its rows hold M0 values:
    // - LHS_OFFSET_X:  distance between the starts of two consecutive blocks on one reshaped row
    // - LHS_STEP_X:    distance between two consecutive rows of the same block
    // - LHS_STEP_LOOP: extra multiplier applied when advancing to the next K0 slice
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (M0)
#define LHS_STEP_X ((M0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (M0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // Number of elements in one K0xN0 RHS block
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X: same meaning as the LHS helpers above, with N0/H0 in place of M0/V0
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    const uint x = get_global_id(0);
    const uint y = get_global_id(1);
    const uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose output block starts outside the M x N result have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address: block (y % V0) of reshaped row (y / V0), batch z
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);

    // Compute RHS matrix address: block (x % H0) of reshaped row (x / H0)
    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_addr += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize the accumulators: c0 .. c(M0-1), one N0-wide vector per output row
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0);

    // Per-row offsets passed to LOAD_BLOCK, all zero: K0 of them for the LHS/RHS loads (both load K0 rows),
    // M0 of them for the bias load
    REPEAT_VAR_INIT_TO_CONST(K0, uint, zlhs, 0);
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

    // k is unsigned, so use an unsigned counter to avoid a signed/unsigned comparison
    for(uint i = 0; i < k; i += K0)
    {
        // Load values from LHS matrix: the block is transposed, so K0 rows of M0 elements (a0 .. a(K0-1))
        LOAD_BLOCK(K0, M0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);

        // Load values from RHS matrix: K0 rows of N0 elements (b0 .. b(K0-1)).
        // zlhs is reused because it already holds the K0 zero offsets this load needs
        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zlhs);

        // Perform the partial matrix multiplication: for each k, c<m> += a<k>.s<m> * b<k>
        ARM_MM_T_NT(M0, N0, K0, DATA_TYPE, a, b, c);

        lhs_addr += (K0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
        rhs_addr += (K0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Drop every kernel-local helper macro (including the *_STEP_LOOP ones) so later kernels start clean
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef LHS_STEP_LOOP
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
2400
2401#endif // defined(LHS_TRANSPOSE)
2402
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002403#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2404
giuros01b3204e72019-04-01 13:50:22 +01002405#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2406
// Vector fused multiply-add: c = fma(a, b, c). All arguments are parenthesized for macro hygiene.
#define VFMA(a, b, c)                 \
    ({                                \
        (c) = fma((a), (b), (c));     \
    })

// RHS_VFMA_M0xN0(i, a, b, c): for each output row m in [0, M0), broadcast component i of the LHS
// row a##m to an N0-wide vector and accumulate its product with the RHS row b into c##m.
// i is pasted onto ".s", so it must be a single component suffix (0-9, A-F).
#if M0 == 1
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
    })
#elif M0 == 2
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
    })
#elif M0 == 3
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
    })
#elif M0 == 4
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
    })
#elif M0 == 5
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
    })
#elif M0 == 6
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
    })
#elif M0 == 7
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
    })
#elif M0 == 8
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported
2483
2484/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2485 * The LHS matrix is NOT reshaped
2486 * The RHS matrix is NOT reshaped
2487 *
2488 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
2490 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
2491 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
2492 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
2493 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
giuros01b3204e72019-04-01 13:50:22 +01002494 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2495 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2496 * - N0 = 2, 3, 4, 8, 16
2497 * - K0 = 2, 3, 4, 8, 16
2498 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002499 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002500 * The activation function is performed after the bias addition
giuros01b3204e72019-04-01 13:50:22 +01002501 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2502 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2503 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2504 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2505 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2506 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2507 *
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002508 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
2509 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
2510 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
2511 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
2512 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2513 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
2514 * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
2515 * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
2516 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
2517 * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
2518 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
2519 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002520 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2521 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2522 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2523 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2524 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2525 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
2527 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2528 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2529 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2530 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2531 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2532 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
2533 * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
2534 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2535 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2536 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2537 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
giuros01b3204e72019-04-01 13:50:22 +01002538 */
2539__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2540 IMAGE_DECLARATION(rhs),
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002541#if defined(BETA)
2542 IMAGE_DECLARATION(bias),
2543#endif // defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002544 IMAGE_DECLARATION(dst),
2545 uint lhs_stride_z,
2546 uint rhs_stride_z,
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002547#if defined(BETA)
2548 uint bias_stride_z,
2549#endif //defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002550 uint dst_stride_z
2551#if defined(REINTERPRET_INPUT_AS_3D)
2552 ,
2553 uint lhs_cross_plane_pad
2554#endif // REINTERPRET_INPUT_AS_3D
2555#if defined(REINTERPRET_OUTPUT_AS_3D)
2556 ,
2557 uint dst_cross_plane_pad
2558#endif // REINTERPRET_OUTPUT_AS_3D
2559 )
2560{
2561 // Block size
2562#define RHS_BLOCK_SIZE ((K0) * (N0))
2563
2564 // RHS offset and step X
2565#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2566
2567 uint x = get_global_id(0);
2568 uint y = get_global_id(1);
2569 uint z = get_global_id(2);
2570
2571#if defined(DUMMY_WORK_ITEMS)
2572 if((x * N0 >= N) || (y * M0 >= M))
2573 {
2574 return;
2575 }
2576#endif // defined(DUMMY_WORK_ITEMS)
2577
2578 // Compute LHS matrix address
2579 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2580
2581 // Compute RHS matrix address
2582 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2583
2584#if defined(MATRIX_B_DEPTH)
2585 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2586 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2587#else // defined(MATRIX_B_DEPTH)
2588 rhs_offset += z * rhs_stride_z;
2589#endif // defined(MATRIX_B_DEPTH)
2590
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002591 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
2592 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
giuros01b3204e72019-04-01 13:50:22 +01002593
2594#if defined(REINTERPRET_INPUT_AS_3D)
2595 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2596 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2597
2598 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2599 // multiply lhs_stride_z by DEPTH_GEMM3D
2600 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2601
2602#else // defined(REINTERPRET_INPUT_AS_3D)
2603
2604 // Add offset for batched GEMM
2605 lhs_offset += z * lhs_stride_z;
2606
2607#endif // defined(REINTERPRET_INPUT_AS_3D)
2608
2609 // Initialize the accumulators
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002610 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
giuros01b3204e72019-04-01 13:50:22 +01002611
2612 int i = 0;
2613 for(; i <= (K - K0); i += K0)
2614 {
2615 // Supported cases (M0, K0):
2616 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2617 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2618 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2619 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2620 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2621 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2622 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2623 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2624 // Load values from LHS matrix
2625 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2626
2627 // Load values from RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002628 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
giuros01b3204e72019-04-01 13:50:22 +01002629
2630 RHS_VFMA_M0xN0(0, a, b0, c);
2631 RHS_VFMA_M0xN0(1, a, b1, c);
2632#if K0 > 2
2633 RHS_VFMA_M0xN0(2, a, b2, c);
2634#endif // K0 > 2
2635#if K0 > 3
2636 RHS_VFMA_M0xN0(3, a, b3, c);
2637#endif // K0 > 3
2638#if K0 > 4
2639 RHS_VFMA_M0xN0(4, a, b4, c);
2640 RHS_VFMA_M0xN0(5, a, b5, c);
2641 RHS_VFMA_M0xN0(6, a, b6, c);
2642 RHS_VFMA_M0xN0(7, a, b7, c);
2643#endif // K0 > 4
2644#if K0 > 8
2645 RHS_VFMA_M0xN0(8, a, b8, c);
2646 RHS_VFMA_M0xN0(9, a, b9, c);
Gian Marco Iodice7b9d7ca2019-09-19 16:37:39 +01002647 RHS_VFMA_M0xN0(A, a, bA, c);
2648 RHS_VFMA_M0xN0(B, a, bB, c);
2649 RHS_VFMA_M0xN0(C, a, bC, c);
2650 RHS_VFMA_M0xN0(D, a, bD, c);
2651 RHS_VFMA_M0xN0(E, a, bE, c);
2652 RHS_VFMA_M0xN0(F, a, bF, c);
giuros01b3204e72019-04-01 13:50:22 +01002653#endif // K0 > 8
2654
2655 lhs_offset += K0 * sizeof(DATA_TYPE);
2656 rhs_offset += K0 * rhs_stride_y;
2657 }
2658
2659 // Left-over accumulations
2660 for(; i < K; ++i)
2661 {
2662 // Load values from LHS matrix
2663 VEC_DATA_TYPE(DATA_TYPE, 2)
2664 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2665#if M0 > 1
2666 VEC_DATA_TYPE(DATA_TYPE, 2)
2667 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2668#endif // M0 > 1
2669#if M0 > 2
2670 VEC_DATA_TYPE(DATA_TYPE, 2)
2671 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2672#endif // M0 > 2
2673#if M0 > 3
2674 VEC_DATA_TYPE(DATA_TYPE, 2)
2675 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2676#endif // M0 > 3
2677#if M0 > 4
2678 VEC_DATA_TYPE(DATA_TYPE, 2)
2679 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2680#endif // M0 > 4
2681#if M0 > 5
2682 VEC_DATA_TYPE(DATA_TYPE, 2)
2683 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2684#endif // M0 > 5
2685#if M0 > 6
2686 VEC_DATA_TYPE(DATA_TYPE, 2)
2687 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2688#endif // M0 > 6
2689#if M0 > 7
2690 VEC_DATA_TYPE(DATA_TYPE, 2)
2691 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2692#endif // M0 > 7
2693
2694 VEC_DATA_TYPE(DATA_TYPE, N0)
2695 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2696 RHS_VFMA_M0xN0(0, a, b, c);
2697
2698 lhs_offset += sizeof(DATA_TYPE);
2699 rhs_offset += rhs_stride_y;
2700 }
2701
2702 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2703
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002704 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
giuros01b3204e72019-04-01 13:50:22 +01002705
2706#if defined(REINTERPRET_OUTPUT_AS_3D)
2707 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2708 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2709
2710 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2711 // multiply dst_stride_z by DEPTH_GEMM3D
2712 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2713
2714#else // defined(REINTERPRET_OUTPUT_AS_3D)
2715
2716 // Add offset for batched GEMM
2717 dst_addr += z * dst_stride_z;
2718
2719#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2720
2721 // Multiply by the weight of matrix-matrix product and store the result
giuros01b3204e72019-04-01 13:50:22 +01002722#if defined(ALPHA)
2723 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2724#endif // defined(ALPHA)
2725
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002726 // Add beta*bias
2727#if defined(BETA)
2728#if defined(BROADCAST_BIAS)
2729 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2730
2731 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2732
2733#ifndef UNIT_BETA
2734 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2735#endif // UNIT_BIAS
2736
2737 // c = c + bias[broadcasted]
2738 ADD_BLOCK_BROADCAST(M0, c, bias0);
2739
2740#else // defined(BROADCAST_BIAS)
2741 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2742 2) * bias_stride_z;
2743
2744 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2745
2746#ifndef UNIT_BETA
2747 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2748#endif // UNIT_BIAS
2749
2750 // c = c + bias
2751 ADD_BLOCK(M0, c, bias);
2752
2753#endif // defined(BROADCAST_BIAS)
2754#endif // defined(BETA)
2755
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002756#if defined(ACTIVATION_TYPE)
2757 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2758#endif // defined(ACTIVATION_TYPE)
2759
giuros01b3204e72019-04-01 13:50:22 +01002760 // Store output block
2761 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2762
2763#undef RHS_BLOCK_SIZE
2764#undef RHS_OFFSET_X
2765#undef RHS_STEP_X
2766}
2767#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2768
Gian Marco36a0a462018-01-12 10:21:40 +00002769#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002770/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002771 *
Gian Marco19835e52018-01-30 13:35:54 +00002772 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002773 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2774 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2775 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2776 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002777 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002778 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2779 * The activation function is performed after the bias addition
2780 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002781 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2782 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2783 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2784 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2785 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002786 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2787 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2788 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2789 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2790 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2791 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002792 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002793 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2794 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2795 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2796 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2797 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002798 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2799 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2800 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2801 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2802 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2803 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002804 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002805 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002806 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002807 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002808 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002809 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002810 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2811 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002812 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002813 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002814 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002815 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002816__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2817 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002818#if defined(BETA)
2819 IMAGE_DECLARATION(src2),
2820#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002821 IMAGE_DECLARATION(dst),
2822 uint src0_stride_z,
2823 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002824#if defined(BETA)
2825 uint src2_stride_z,
2826#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002827 uint dst_stride_z
2828#if defined(REINTERPRET_OUTPUT_AS_3D)
2829 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002830 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002831#endif // REINTERPRET_OUTPUT_AS_3D
2832 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002833{
Gian Marco36a0a462018-01-12 10:21:40 +00002834 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2835 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002836 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002837
Gian Marco36a0a462018-01-12 10:21:40 +00002838 // Offset
2839 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2840 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002841
Gian Marco36a0a462018-01-12 10:21:40 +00002842 // src_addr_a = address of matrix A
2843 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002844 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2845 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2846
2847#if defined(MATRIX_B_DEPTH)
2848 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2849 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2850#else // defined(MATRIX_B_DEPTH)
2851 src1_addr_in_bytes += z * src1_stride_z;
2852#endif // defined(MATRIX_B_DEPTH)
2853
2854 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2855 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002856
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002857 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002858 __global float *src_end_addr_b = src_addr_b + COLS_B;
2859
2860 src_addr_a += offset_row_a;
2861 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002862
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002863 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002864 float4 c0 = 0.0f;
2865 float4 c1 = 0.0f;
2866 float4 c2 = 0.0f;
2867 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002868
Gian Marco36a0a462018-01-12 10:21:40 +00002869 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002870 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002871 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002872 float4 a0 = vload4(0, src_addr_a);
2873 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002874
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002875 c0 += (float4)a0.s0 * b0;
2876 c1 += (float4)a0.s1 * b0;
2877 c2 += (float4)a0.s2 * b0;
2878 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002879
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002880 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002881 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2882 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002883
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002884 c0 += (float4)a0.s0 * b0;
2885 c1 += (float4)a0.s1 * b0;
2886 c2 += (float4)a0.s2 * b0;
2887 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002888 }
2889
Gian Marco36a0a462018-01-12 10:21:40 +00002890 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002891 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002892 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002893 float4 a0 = vload4(0, src_addr_a);
2894 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002895
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002896 c0 += (float4)a0.s0 * b0;
2897 c1 += (float4)a0.s1 * b0;
2898 c2 += (float4)a0.s2 * b0;
2899 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002900 }
2901
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002902 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002903 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2904
Gian Marcoae2af742018-02-15 12:35:44 +00002905 // Compute dst address
2906 __global uchar *dst_addr = offset(&dst, 0, 0);
2907
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002908 uint4 zout = 0;
2909
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002910#if defined(REINTERPRET_OUTPUT_AS_3D)
2911 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002912 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002913 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002914 // | |
2915 // | plane0 |
2916 // | |
2917 // |__________________|
2918 // |******************|
2919 // | cross_plane_pad |
2920 // |******************|
2921 // | |
2922 // | plane1 |
2923 // | |
2924 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002925
2926 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002927 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2928 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002929
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002930 // Add offset due to the cross plane paddings
2931 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002932
2933 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2934 // multiply dst_stride_z by DEPTH_GEMM3D
2935 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002936#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002937 // Add offset for batched GEMM
2938 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002939#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2940
2941 // Multiply by the weight of matrix-matrix product and store the result
2942#if defined(ALPHA)
2943 SCALE_BLOCK(4, float, c, ALPHA);
2944#endif // defined(ALPHA)
2945
2946 // Add beta*bias
2947#if defined(BETA)
2948 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
2949
2950#if defined(BROADCAST_BIAS)
2951 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
2952
2953 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2954
2955#ifndef UNIT_BETA
2956 SCALE_BLOCK(1, float, bias, BETA);
2957#endif // UNIT_BIAS
2958
2959 // c = c + bias[broadcasted]
2960 ADD_BLOCK_BROADCAST(4, c, bias0);
2961
2962#else // defined(BROADCAST_BIAS)
2963 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
2964 2) * src2_stride_z;
2965
2966 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2967
2968#ifndef UNIT_BETA
2969 SCALE_BLOCK(4, float, bias, BETA);
2970#endif // UNIT_BIAS
2971
2972 // c = c + bias
2973 ADD_BLOCK(4, c, bias);
2974
2975#endif // defined(BROADCAST_BIAS)
2976#endif // defined(BETA)
2977
2978#if defined(ACTIVATION_TYPE)
2979 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
2980#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00002981
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002982 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002983 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2984 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2985 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2986 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002987}
2988
/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002990 *
Gian Marco19835e52018-01-30 13:35:54 +00002991 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002992 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2995 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2996 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002997 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2999 * The activation function is performed after the bias addition
3000 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003001 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3002 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3003 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3004 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3005 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003006 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3007 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
3017 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3019 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3020 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3021 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3022 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3023 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003025 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003026 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003027 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003028 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003029 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003030 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3031 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003032 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003033 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003034 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003035 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Each work-item produces one 4x4 block of the output (accumulators c0..c3, stored with vstore4).
    // MULT_TRANSPOSE1XW_WIDTH consecutive work-items along X share one row (x) of reshaped matrix B, and
    // MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y share one row (y) of reshaped matrix A.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element sub-block inside the shared rows of A and B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    // NOTE(review): the byte offsets are stored in int, so they would overflow for buffers larger than 2^31 bytes
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cN holds output row N of the 4x4 block
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

    // Number of accumulation steps along K: each step consumes 4 * MULT_TRANSPOSE1XW_WIDTH elements of the
    // reshaped B row (and 4 * MULT_INTERLEAVE4X4_HEIGHT elements of the reshaped A row).
    // Local to this kernel: it is #undef'd right after the kernel body.
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    // Main loop: 4 accumulation steps per iteration, written out by hand with per-component fma() calls.
    // Every step is a rank-1 update: cR.sC += a0.sR * b0.sC (a0 = one value per output row, b0 = one value
    // per output column).
    // NOTE(review): the scalar form is presumably tuned for Bifrost (kernel name); keep the statement order when editing.
    int i = 0;
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
    }

    // Leftover loop: the remaining COLS_MTX_B % 4 accumulation steps, one per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Extra byte offset applied to each of the 4 output rows at store time.
    // It stays 0 unless the output is reinterpreted as a 3D tensor.
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings: convert the plane index into bytes
    // (cross_plane_pad padding rows per plane, each dst_stride_y bytes)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the accumulated matrix-matrix product by ALPHA (the result is stored at the end of the kernel)
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // zero0..zero3 = 0 are passed to LOAD_BLOCK in the position where the store uses zout, so no
    // cross-plane offset is applied to the bias rows (NOTE(review): LOAD_BLOCK is defined in gemm_helpers.h — confirm)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (4 floats at this work-item's X position) is shared by all output rows and batches
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // A full 4x4 bias block at the same X/Y/Z position as the output block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // The activation is applied after the alpha scaling and the bias addition
#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block (row N additionally offset by zout.sN)
    vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
}

// COLS_MTX_B is only meaningful inside gemm_mm_interleaved_transposed_f32_bifrost: undefine it so it
// cannot leak into the kernels that follow
#undef COLS_MTX_B
3319
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003320#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * Each work-item computes a 4x8 block of the output and accumulates it in half precision (half8 accumulators).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The optional bias matrix (src2) is enabled by passing its weight at compile time using -DBETA. Passing -DUNIT_BETA skips the multiplication of the bias by BETA,
 *       and passing -DBROADCAST_BIAS loads a single bias row that is added to every row of the output block
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements along Y, i.e. in rows of dst_stride_y bytes (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // MULT_TRANSPOSE1XW_WIDTH consecutive work-items along X share one row (x) of reshaped matrix B, and
    // MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y share one row (y) of reshaped matrix A
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's sub-block inside the shared rows:
    // 4 elements per step in A (one per output row), 8 elements per step in B (one per output column)
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    // NOTE(review): the byte offsets are stored in int, so they would overflow for buffers larger than 2^31 bytes
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B.
    // This is taken from the start of the shared row, before offset_row_b is applied. Since every step
    // advances src_addr_b by 8 * MULT_TRANSPOSE1XW_WIDTH and offset_row_b is smaller than that, the loops
    // below run COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH) steps for every work-item.
    // NOTE(review): this assumes COLS_B is a multiple of 8 * MULT_TRANSPOSE1XW_WIDTH — confirm with the reshape.
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cN holds output row N of the 4x8 block
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration (the second step reads 4 * MULT_INTERLEAVE4X4_HEIGHT /
    // 8 * MULT_TRANSPOSE1XW_WIDTH elements further on). Each step is a rank-1 update cN += a0.sN * b0.
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Leftover loop: one accumulation step per iteration until the end of the B row
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Extra byte offset applied to each of the 4 output rows at store time.
    // It stays 0 unless the output is reinterpreted as a 3D tensor.
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings: convert the plane index into bytes
    // (cross_plane_pad padding rows per plane, each dst_stride_y bytes)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the accumulated matrix-matrix product by ALPHA (the result is stored at the end of the kernel)
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // zero0..zero3 = 0 are passed to LOAD_BLOCK in the position where the store uses zout, so no
    // cross-plane offset is applied to the bias rows (NOTE(review): LOAD_BLOCK is defined in gemm_helpers.h — confirm)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (8 halfs at this work-item's X position) is shared by all output rows and batches
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)

    // A full 4x8 bias block at the same X/Y/Z position as the output block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // The activation is applied after the alpha scaling and the bias addition
#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block (row N additionally offset by zout.sN)
    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003540
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003541/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32 floating point variable.
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003542 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003543 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003544 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3545 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3546 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3547 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003548 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003549 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3550 * The activation function is performed after the bias addition
3551 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003552 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3553 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3554 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3555 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3556 *
3557 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3558 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3559 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3560 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3561 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3562 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3563 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3564 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3565 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3566 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3567 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3568 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3570 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3571 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3572 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3573 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3574 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003575 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3576 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3577 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3578 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3579 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3580 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3581 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3582 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003583 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003584 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3585 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3586 */
3587__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
3588 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003589#if defined(BETA)
3590 IMAGE_DECLARATION(src2),
3591#endif // defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003592 IMAGE_DECLARATION(dst),
3593 uint src0_stride_z,
3594 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003595#if defined(BETA)
3596 uint src2_stride_z,
3597#endif //defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003598 uint dst_stride_z
3599#if defined(REINTERPRET_OUTPUT_AS_3D)
3600 ,
3601 uint cross_plane_pad
3602#endif // REINTERPRET_OUTPUT_AS_3D
3603 )
3604{
3605 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3606 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3607 int z = get_global_id(2);
3608
3609 // Offset
3610 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3611 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3612
3613 // src_addr_a = address of matrix A
3614 // src_addr_b = address of matrix B
3615 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3616 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3617
3618#if defined(MATRIX_B_DEPTH)
3619 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3620 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3621#else // defined(MATRIX_B_DEPTH)
3622 src1_addr_in_bytes += z * src1_stride_z;
3623#endif // defined(MATRIX_B_DEPTH)
3624
3625 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3626 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3627
3628 // Compute end row address for matrix B
3629 __global half *src_end_addr_b = src_addr_b + COLS_B;
3630
3631 src_addr_a += offset_row_a;
3632 src_addr_b += offset_row_b;
3633
3634 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003635 float8 c0 = 0.0f;
3636 float8 c1 = 0.0f;
3637 float8 c2 = 0.0f;
3638 float8 c3 = 0.0f;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003639
3640 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
3641 {
3642 // Load values from matrix A (interleaved) and matrix B (transposed)
3643 float4 a0 = convert_float4(vload4(0, src_addr_a));
3644 float8 b0 = convert_float8(vload8(0, src_addr_b));
3645
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003646 c0 += (float8)a0.s0 * b0;
3647 c1 += (float8)a0.s1 * b0;
3648 c2 += (float8)a0.s2 * b0;
3649 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003650
3651 // Load values from matrix A (interleaved) and matrix B (transposed)
3652 a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
3653 b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
3654
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003655 c0 += (float8)a0.s0 * b0;
3656 c1 += (float8)a0.s1 * b0;
3657 c2 += (float8)a0.s2 * b0;
3658 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003659 }
3660
3661 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
3662 {
3663 // Load values from matrix A (interleaved) and matrix B (transposed)
3664 float4 a0 = convert_float4(vload4(0, src_addr_a));
3665 float8 b0 = convert_float8(vload8(0, src_addr_b));
3666
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003667 c0 += (float8)a0.s0 * b0;
3668 c1 += (float8)a0.s1 * b0;
3669 c2 += (float8)a0.s2 * b0;
3670 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003671 }
3672
3673 // Compute destination address
3674 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3675
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003676 // Compute dst address
3677 __global uchar *dst_addr = offset(&dst, 0, 0);
3678
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003679 uint4 zout = 0;
3680
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003681#if defined(REINTERPRET_OUTPUT_AS_3D)
3682 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
3683 // in order to take into account the presence of possible cross plane paddings
3684 //
3685 // | |
3686 // | plane0 |
3687 // | |
3688 // |__________________|
3689 // |******************|
3690 // | cross_plane_pad |
3691 // |******************|
3692 // | |
3693 // | plane1 |
3694 // | |
3695 // |__________________|
3696
3697 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003698 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3699 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003700
3701 // Add offset due to the cross plane paddings
3702 zout *= (cross_plane_pad * dst_stride_y);
3703
3704 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3705 // multiply dst_stride_z by DEPTH_GEMM3D
3706 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003707#else // defined(REINTERPRET_OUTPUT_AS_3D)
3708 // Add offset for batched GEMM
3709 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003710#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3711
3712 // Multiply by the weight of matrix-matrix product and store the result
3713#if defined(ALPHA)
3714 SCALE_BLOCK(4, float, c, ALPHA);
3715#endif // defined(ALPHA)
3716
3717#if defined(BETA)
3718 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3719
3720#if defined(BROADCAST_BIAS)
3721 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
3722
3723 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3724
3725 float8 bias_f0 = convert_float8(bias0);
3726
3727#ifndef UNIT_BETA
3728 SCALE_BLOCK(1, float, bias_f, BETA);
3729#endif // UNIT_BIAS
3730
3731 // c = c + bias[broadcasted]
3732 ADD_BLOCK_BROADCAST(4, c, bias_f0);
3733
3734#else // defined(BROADCAST_BIAS)
3735 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3736 2) * src2_stride_z;
3737
3738 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3739
3740 float8 bias_f0 = convert_float8(bias0);
3741 float8 bias_f1 = convert_float8(bias1);
3742 float8 bias_f2 = convert_float8(bias2);
3743 float8 bias_f3 = convert_float8(bias3);
3744
3745#ifndef UNIT_BETA
3746 SCALE_BLOCK(4, float, bias_f, BETA);
3747#endif // UNIT_BIAS
3748
3749 // c = c + bias
3750 ADD_BLOCK(4, c, bias_f);
3751
3752#endif // defined(BROADCAST_BIAS)
3753#endif // defined(BETA)
3754
3755 half8 c_h0 = convert_half8(c0);
3756 half8 c_h1 = convert_half8(c1);
3757 half8 c_h2 = convert_half8(c2);
3758 half8 c_h3 = convert_half8(c3);
3759
3760#if defined(ACTIVATION_TYPE)
3761 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
3762#endif // defined(ACTIVATION_TYPE)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003763
3764 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003765 vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3766 vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3767 vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3768 vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003769}
3770
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003771/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003772 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003773 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003774 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3775 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3776 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3777 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003778 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003779 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3780 * The activation function is performed after the bias addition
3781 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003782 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3783 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3784 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3785 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3786 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003787 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3788 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3789 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3790 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3791 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3792 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3793 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3794 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3795 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3796 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3797 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3798 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003799 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3800 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3801 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3802 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3803 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3804 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003805 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3806 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3807 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3808 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3809 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3810 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003811 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3812 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3813 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003814 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003815 */
3816__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
3817 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003818#if defined(BETA)
3819 IMAGE_DECLARATION(src2),
3820#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003821 IMAGE_DECLARATION(dst),
3822 uint src0_stride_z,
3823 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003824#if defined(BETA)
3825 uint src2_stride_z,
3826#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003827 uint dst_stride_z
3828#if defined(REINTERPRET_OUTPUT_AS_3D)
3829 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003830 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003831#endif // REINTERPRET_OUTPUT_AS_3D
3832 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003833{
3834 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3835 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3836 int z = get_global_id(2);
3837
3838 // Offset
3839 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3840 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3841
3842 // src_addr_a = address of matrix A
3843 // src_addr_b = address of matrix B
3844 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3845 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3846
3847#if defined(MATRIX_B_DEPTH)
3848 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3849 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3850#else // defined(MATRIX_B_DEPTH)
3851 src1_addr_in_bytes += z * src1_stride_z;
3852#endif // defined(MATRIX_B_DEPTH)
3853
3854 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3855 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3856
3857 // Compute end row address for matrix B
3858 __global half *src_end_addr_b = src_addr_b + COLS_B;
3859
3860 src_addr_a += offset_row_a;
3861 src_addr_b += offset_row_b;
3862
3863 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003864 half8 c0 = 0.0f;
3865 half8 c1 = 0.0f;
3866 half8 c2 = 0.0f;
3867 half8 c3 = 0.0f;
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003868
3869#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
3870
3871 int i = 0;
3872 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
3873 {
3874#if MULT_INTERLEAVE4X4_HEIGHT == 1
3875 // Load values from matrix A (interleaved) and matrix B (transposed)
3876 half8 a0 = vload8(0, src_addr_a);
3877 half8 b0 = vload8(0, src_addr_b);
3878
3879 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3880 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3881
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003882 c0 = fma((half8)a0.s0, b0, c0);
3883 c1 = fma((half8)a0.s1, b0, c1);
3884 c2 = fma((half8)a0.s2, b0, c2);
3885 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003886
3887 // Load values from matrix B (transposed)
3888 b0 = vload8(0, src_addr_b);
3889
3890 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3891
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003892 c0 = fma((half8)a0.s4, b0, c0);
3893 c1 = fma((half8)a0.s5, b0, c1);
3894 c2 = fma((half8)a0.s6, b0, c2);
3895 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003896
3897 // Load values from matrix A (interleaved) and matrix B (transposed)
3898 a0 = vload8(0, src_addr_a);
3899 b0 = vload8(0, src_addr_b);
3900
3901 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3902 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3903
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003904 c0 = fma((half8)a0.s0, b0, c0);
3905 c1 = fma((half8)a0.s1, b0, c1);
3906 c2 = fma((half8)a0.s2, b0, c2);
3907 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003908
3909 // Load values from matrix B (transposed)
3910 b0 = vload8(0, src_addr_b);
3911
3912 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3913
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003914 c0 = fma((half8)a0.s4, b0, c0);
3915 c1 = fma((half8)a0.s5, b0, c1);
3916 c2 = fma((half8)a0.s6, b0, c2);
3917 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003918#else // MULT_INTERLEAVE4X4_HEIGHT == 1
3919 // Load values from matrix A (interleaved) and matrix B (transposed)
3920 half4 a0 = vload4(0, src_addr_a);
3921 half8 b0 = vload8(0, src_addr_b);
3922
3923 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3924 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3925
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003926 c0 = fma((half8)a0.s0, b0, c0);
3927 c1 = fma((half8)a0.s1, b0, c1);
3928 c2 = fma((half8)a0.s2, b0, c2);
3929 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003930
3931 // Load values from matrix A (interleaved) and matrix B (transposed)
3932 a0 = vload4(0, src_addr_a);
3933 b0 = vload8(0, src_addr_b);
3934
3935 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3936 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3937
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003938 c0 = fma((half8)a0.s0, b0, c0);
3939 c1 = fma((half8)a0.s1, b0, c1);
3940 c2 = fma((half8)a0.s2, b0, c2);
3941 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003942
3943 // Load values from matrix A (interleaved) and matrix B (transposed)
3944 a0 = vload4(0, src_addr_a);
3945 b0 = vload8(0, src_addr_b);
3946
3947 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3948 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3949
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003950 c0 = fma((half8)a0.s0, b0, c0);
3951 c1 = fma((half8)a0.s1, b0, c1);
3952 c2 = fma((half8)a0.s2, b0, c2);
3953 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003954
3955 // Load values from matrix A (interleaved) and matrix B (transposed)
3956 a0 = vload4(0, src_addr_a);
3957 b0 = vload8(0, src_addr_b);
3958
3959 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3960 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3961
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003962 c0 = fma((half8)a0.s0, b0, c0);
3963 c1 = fma((half8)a0.s1, b0, c1);
3964 c2 = fma((half8)a0.s2, b0, c2);
3965 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003966#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
3967 }
3968
3969 for(; i < (int)(COLS_MTX_B); ++i)
3970 {
3971 // Load values from matrix A (interleaved) and matrix B (transposed)
3972 half4 a0 = vload4(0, src_addr_a);
3973 half8 b0 = vload8(0, src_addr_b);
3974
3975 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3976 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3977
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003978 c0 = fma((half8)a0.s0, b0, c0);
3979 c1 = fma((half8)a0.s1, b0, c1);
3980 c2 = fma((half8)a0.s2, b0, c2);
3981 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003982 }
3983
3984 // Compute destination address
3985 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3986
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003987 // Compute dst address
3988 __global uchar *dst_addr = offset(&dst, 0, 0);
3989
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003990 uint4 zout = 0;
3991
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003992#if defined(REINTERPRET_OUTPUT_AS_3D)
3993 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003994 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003995 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003996 // | |
3997 // | plane0 |
3998 // | |
3999 // |__________________|
4000 // |******************|
4001 // | cross_plane_pad |
4002 // |******************|
4003 // | |
4004 // | plane1 |
4005 // | |
4006 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004007
4008 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004009 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4010 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004011
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004012 // Add offset due to the cross plane paddings
4013 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004014
4015 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4016 // multiply dst_stride_z by DEPTH_GEMM3D
4017 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004018#else // defined(REINTERPRET_OUTPUT_AS_3D)
4019 // Add offset for batched GEMM
4020 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004021#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4022
4023 // Multiply by the weight of matrix-matrix product and store the result
4024#if defined(ALPHA)
4025 SCALE_BLOCK(4, half, c, ALPHA);
4026#endif // defined(ALPHA)
4027
4028 // Add beta*bias
4029#if defined(BETA)
4030 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4031
4032#if defined(BROADCAST_BIAS)
4033 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4034
4035 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4036
4037#ifndef UNIT_BETA
4038 SCALE_BLOCK(1, half, bias, BETA);
4039#endif // UNIT_BIAS
4040
4041 // c = c + bias[broadcasted]
4042 ADD_BLOCK_BROADCAST(4, c, bias0);
4043
4044#else // defined(BROADCAST_BIAS)
4045 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4046 2) * src2_stride_z;
4047
4048 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4049
4050#ifndef UNIT_BETA
4051 SCALE_BLOCK(4, half, bias, BETA);
4052#endif // UNIT_BIAS
4053
4054 // c = c + bias
4055 ADD_BLOCK(4, c, bias);
4056
4057#endif // defined(BROADCAST_BIAS)
4058#endif // defined(BETA)
4059
4060#if defined(ACTIVATION_TYPE)
4061 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4062#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004063
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004064 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004065 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4066 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4067 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4068 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004069}
Georgios Pinitas84225582018-05-14 12:00:05 +01004070
4071// Undefine local defines
4072#undef COLS_MTX_B
4073
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004074#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004075
Gian Marco36a0a462018-01-12 10:21:40 +00004076#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004077
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004078#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
4079#if defined(DATA_TYPE)
4080#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004081/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
4082 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004083 * @note This OpenCL kernel works with floating point data types (F16/F32)
4084 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
4085 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004086 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004087 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4088 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004089 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004092 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4093 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004094 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4095 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4096 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4097 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4098 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004099 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004100 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4101 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4102 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4103 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4104 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004105 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004106 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4107 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4108 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4109 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4110 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4112 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4113 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4114 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4115 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4116 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src2_stride_z (Optional, only if defined BETA) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] src_cross_plane_pad (Optional) Bottom padding of each plane of the input tensor, in rows (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in] dst_cross_plane_pad (Optional) Bottom padding of each plane of the output tensor, in rows (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif // defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    // Index of the first column of matrix B (and of the output) computed by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets of the current read positions: .s0 walks along the rows of matrix A,
    // .s1 walks down the columns of matrix B. Both advance together in the K loops below.
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for matrix A: this work-item starts at row get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for matrix B: this work-item starts at column idx
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin = min(DEPTH_GEMM3D - 1, zin);

    // Turn each row's plane index into the byte offset caused by the cross plane paddings above it
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset (in src0) one past the last element of the current row of matrix A: bound of the K loops
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator per output row handled by this work-item
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main K loop: each iteration consumes two columns of matrix A and the two matching rows of matrix B
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (zin adds the per-row cross plane padding offset)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (two consecutive rows)
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over K loop: after the main loop at most one column of matrix A remains (when COLS_A is odd)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (zin adds the per-row cross plane padding offset)
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row cross plane byte offsets used by the store; all zero unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Turn each row's plane index into the byte offset caused by the cross plane paddings above it
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product (alpha)
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // Zero cross plane offsets for the bias loads: the bias is read as a plain 2D matrix
    // (presumably declares zero0, zero1, ... one per row - TODO confirm against repeat.h)
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (no get_global_id(1)/(2) offset) is shared by every output row
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // The bias block is addressed exactly like the output tile (same rows, columns and batch)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // The activation is applied after alpha scaling and bias addition
#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block (zout adds the per-row cross plane padding offset, zero when not 3D)
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004391#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004392
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 * The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A (not reshaped)
 *
 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional, only if defined BETA) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src2_stride_z (Optional, only if defined BETA) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] src_cross_plane_pad (Optional) Bottom padding of each plane of the input tensor, in rows (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in] dst_cross_plane_pad (Optional) Bottom padding of each plane of the output tensor, in rows (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
4443__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4444 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004445#if defined(BETA)
4446 IMAGE_DECLARATION(src2),
4447#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004448 IMAGE_DECLARATION(dst),
4449 uint src0_stride_z,
4450 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004451#if defined(BETA)
4452 uint src2_stride_z,
4453#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004454 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004455#if defined(REINTERPRET_INPUT_AS_3D)
4456 ,
4457 uint src_cross_plane_pad
4458#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004459#if defined(REINTERPRET_OUTPUT_AS_3D)
4460 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004461 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004462#endif // REINTERPRET_OUTPUT_AS_3D
4463 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004464{
4465 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4466
4467 // Compute starting address for matrix A and matrix B
4468 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4469
4470 // Update address for matrix A
4471 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4472
4473 // Update address for matrix B
4474 src_addr.s1 += idx * sizeof(float);
4475
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004476#if defined(REINTERPRET_INPUT_AS_3D)
4477 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4478 // in order to take into account the presence of possible cross plane paddings
4479 //
4480 // | |
4481 // | plane0 |
4482 // | |
4483 // |__________________|
4484 // |******************|
4485 // | cross_plane_pad |
4486 // |******************|
4487 // | |
4488 // | plane1 |
4489 // | |
4490 // |__________________|
4491
4492 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4493 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4494 zin = min(DEPTH_GEMM3D - 1, zin);
4495
4496 // Add offset due to the cross plane paddings
4497 zin *= (src_cross_plane_pad * src0_stride_y);
4498
4499 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4500 // multiply src0_stride_z by DEPTH_GEMM3D
4501 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4502
4503#else // defined(REINTERPRET_INPUT_AS_3D)
4504
Gian Marcoae2af742018-02-15 12:35:44 +00004505 // Add offset for batched GEMM
4506 src_addr.s0 += get_global_id(2) * src0_stride_z;
4507
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004508#endif // defined(REINTERPRET_INPUT_AS_3D)
4509
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004510#if defined(MATRIX_B_DEPTH)
4511 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4512 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4513#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004514 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004515#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004516
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004517 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004518 float4 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004519
4520#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004521 float4 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004522#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4523
4524#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004525 float4 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004526#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4527
4528#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004529 float4 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004530#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4531
4532 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004533 int i = 0;
4534 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004535 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004536#if defined(REINTERPRET_INPUT_AS_3D)
4537 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004538 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4539#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004540 // Load values from matrix A and matrix B
4541 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004542#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004543 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004544#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4545#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004546 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004547#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4548#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004549 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004550#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004551#endif // defined(REINTERPRET_INPUT_AS_3D)
4552
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004553 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4554 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004555
4556 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004557 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
4558 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
4559 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
4560 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004561
4562#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004563
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004564 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
4565 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
4566 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
4567 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004568
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004569#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4570#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004571
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004572 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
4573 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
4574 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
4575 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004576
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004577#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4578#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004579
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004580 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
4581 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
4582 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
4583 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004584#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004585
4586 // Load values from matrix A and matrix B
4587 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4588 src_addr.s1 += src1_stride_y;
4589
4590 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004591 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
4592 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
4593 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
4594 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004595
4596#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4597
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004598 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
4599 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
4600 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
4601 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004602
4603#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4604#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4605
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004606 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
4607 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
4608 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
4609 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004610
4611#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4612#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4613
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004614 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
4615 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
4616 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
4617 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004618#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4619
4620 // Load values from matrix A and matrix B
4621 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4622 src_addr.s1 += src1_stride_y;
4623
4624 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004625 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
4626 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
4627 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
4628 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004629
4630#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4631
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004632 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
4633 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
4634 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
4635 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004636
4637#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4638#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4639
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004640 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
4641 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
4642 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
4643 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004644
4645#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4646#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4647
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004648 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
4649 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
4650 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
4651 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004652#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4653
4654 // Load values from matrix A and matrix B
4655 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4656 src_addr.s1 += src1_stride_y;
4657
4658 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004659 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
4660 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
4661 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
4662 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004663
4664#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4665
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004666 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
4667 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
4668 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
4669 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004670
4671#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4672#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4673
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004674 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
4675 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
4676 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
4677 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004678
4679#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4680#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4681
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004682 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
4683 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
4684 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
4685 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004686#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4687
4688 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004689 }
4690
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004691 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004692 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004693#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004694 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004695 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4696#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4697 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4698#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4699#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4700 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4701#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4702#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4703 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4704#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4705#else // defined(REINTERPRET_INPUT_AS_3D)
4706 // Load values from matrix A
4707 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004708#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4709 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4710#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4711#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4712 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4713#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4714#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4715 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4716#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004717#endif // defined(REINTERPRET_INPUT_AS_3D)
4718
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004719 // Load values from matrix B
4720 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004721 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004722
4723 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004724 acc0.s0 = fma(a0, b0.s0, acc0.s0);
4725 acc0.s1 = fma(a0, b0.s1, acc0.s1);
4726 acc0.s2 = fma(a0, b0.s2, acc0.s2);
4727 acc0.s3 = fma(a0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004728#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004729 acc1.s0 = fma(a1, b0.s0, acc1.s0);
4730 acc1.s1 = fma(a1, b0.s1, acc1.s1);
4731 acc1.s2 = fma(a1, b0.s2, acc1.s2);
4732 acc1.s3 = fma(a1, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004733#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4734#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004735 acc2.s0 = fma(a2, b0.s0, acc2.s0);
4736 acc2.s1 = fma(a2, b0.s1, acc2.s1);
4737 acc2.s2 = fma(a2, b0.s2, acc2.s2);
4738 acc2.s3 = fma(a2, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004739#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4740#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004741 acc3.s0 = fma(a3, b0.s0, acc3.s0);
4742 acc3.s1 = fma(a3, b0.s1, acc3.s1);
4743 acc3.s2 = fma(a3, b0.s2, acc3.s2);
4744 acc3.s3 = fma(a3, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004745#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004746
4747 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004748 }
4749
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004750 int z = get_global_id(2);
4751
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004752 // Compute destination address
4753 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4754
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004755 // Compute dst address
4756 __global uchar *dst_addr = offset(&dst, 0, 0);
4757
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004758 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004759
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004760#if defined(REINTERPRET_OUTPUT_AS_3D)
4761 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004762 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004763 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004764 // | |
4765 // | plane0 |
4766 // | |
4767 // |__________________|
4768 // |******************|
4769 // | cross_plane_pad |
4770 // |******************|
4771 // | |
4772 // | plane1 |
4773 // | |
4774 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004775
4776 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004777 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4778 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004779
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004780 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004781 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004782
4783 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4784 // multiply dst_stride_z by DEPTH_GEMM3D
4785 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004786#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004787 // Add offset for batched GEMM
4788 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004789#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4790
4791 // Multiply by the weight of matrix-matrix product and store the result
4792#if defined(ALPHA)
4793 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
4794#endif // defined(ALPHA)
4795
4796 // Add beta*bias
4797#if defined(BETA)
4798 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
4799
4800#if defined(BROADCAST_BIAS)
4801 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
4802
4803 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4804
4805#ifndef UNIT_BETA
4806 SCALE_BLOCK(1, float, bias, BETA);
4807#endif // UNIT_BIAS
4808
4809 // acc = acc + bias[broadcasted]
4810 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
4811
4812#else // defined(BROADCAST_BIAS)
4813 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
4814 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
4815
4816 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4817
4818#ifndef UNIT_BETA
4819 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
4820#endif // UNIT_BIAS
4821
4822 // acc = acc + bias
4823 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
4824
4825#endif // defined(BROADCAST_BIAS)
4826#endif // defined(BETA)
4827
4828#if defined(ACTIVATION_TYPE)
4829 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
4830#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004831
4832 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004833 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004834#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004835 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004836#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4837#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004838 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004839#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4840#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004841 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004842#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004843}
4844
4845/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4846 *
4847 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4848 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
4849 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
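4850 * This kernel computes 2 output columns per work-item (float2 accumulators, vload2 loads from matrix B), so it must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.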
4851 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4852 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004853 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4854 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004855 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004856 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
4857 * The activation function is applied after the bias addition.
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004858 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4859 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004860 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4861 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4862 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4863 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
4864 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004865 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004866 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4867 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work-item (in bytes)
4868 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4869 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work-item (in bytes)
4870 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4871 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4872 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4873 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work-item (in bytes)
4874 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4875 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work-item (in bytes)
4876 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004877 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4878 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4879 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4880 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4881 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4882 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004883 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4884 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4885 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work-item (in bytes)
4886 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4887 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work-item (in bytes)
4888 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004889 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4890 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004891 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004892 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004893 * @param[in] src_cross_plane_pad (Optional) Bottom padding between the planes of the input tensor, expressed in rows (only if REINTERPRET_INPUT_AS_3D is defined)
4894 * @param[in] dst_cross_plane_pad (Optional) Bottom padding between the planes of the output tensor, expressed in rows (only if REINTERPRET_OUTPUT_AS_3D is defined)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004895 */
4896__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
4897 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004898#if defined(BETA)
4899 IMAGE_DECLARATION(src2),
4900#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004901 IMAGE_DECLARATION(dst),
4902 uint src0_stride_z,
4903 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004904#if defined(BETA)
4905 uint src2_stride_z,
4906#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004907 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004908#if defined(REINTERPRET_INPUT_AS_3D)
4909 ,
4910 uint src_cross_plane_pad
4911#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004912#if defined(REINTERPRET_OUTPUT_AS_3D)
4913 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004914 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004915#endif // REINTERPRET_OUTPUT_AS_3D
4916 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004917{
4918 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4919 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4920
4921 // Compute starting address for matrix A and Matrix B
4922 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4923
4924 // Update address for the matrix A
4925 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4926
4927 // Update address for the matrix B
4928 src_addr.s1 += idx * sizeof(float);
4929
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004930#if defined(REINTERPRET_INPUT_AS_3D)
4931 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4932 // in order to take into account the presence of possible cross plane paddings
4933 //
4934 // | |
4935 // | plane0 |
4936 // | |
4937 // |__________________|
4938 // |******************|
4939 // | cross_plane_pad |
4940 // |******************|
4941 // | |
4942 // | plane1 |
4943 // | |
4944 // |__________________|
4945
4946 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4947 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4948 zin = min(DEPTH_GEMM3D - 1, zin);
4949
4950 // Add offset due to the cross plane paddings
4951 zin *= (src_cross_plane_pad * src0_stride_y);
4952
4953 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4954 // multiply src0_stride_z by DEPTH_GEMM3D
4955 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4956
4957#else // defined(REINTERPRET_INPUT_AS_3D)
4958
Gian Marcoae2af742018-02-15 12:35:44 +00004959 // Add offset for batched GEMM
4960 src_addr.s0 += get_global_id(2) * src0_stride_z;
4961
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004962#endif // defined(REINTERPRET_INPUT_AS_3D)
4963
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004964#if defined(MATRIX_B_DEPTH)
4965 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4966 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4967#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004968 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004969#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004970
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004971 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004972 float2 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004973#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004974 float2 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004975#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4976#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004977 float2 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004978#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4979#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004980 float2 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004981#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4982
4983 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004984 int i = 0;
4985 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004986 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004987#if defined(REINTERPRET_INPUT_AS_3D)
4988 // Load values from matrix A
4989 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
4990#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004991 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004992 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004993#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004994
4995 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004996 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4997 src_addr.s1 += src1_stride_y;
4998 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4999 src_addr.s1 += src1_stride_y;
5000 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5001 src_addr.s1 += src1_stride_y;
5002 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5003 src_addr.s1 += src1_stride_y;
5004 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5005 src_addr.s1 += src1_stride_y;
5006 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5007 src_addr.s1 += src1_stride_y;
5008 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5009 src_addr.s1 += src1_stride_y;
5010 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5011 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005012
5013 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005014 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5015 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
5016 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
5017 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
5018 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
5019 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
5020 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
5021 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005022
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005023 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5024 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
5025 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
5026 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
5027 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
5028 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
5029 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
5030 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005031
5032#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005033#if defined(REINTERPRET_INPUT_AS_3D)
5034 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5035#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005036 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005037#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005038 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
5039 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
5040 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
5041 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
5042 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
5043 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
5044 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
5045 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005046
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005047 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
5048 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
5049 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
5050 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
5051 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
5052 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
5053 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
5054 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005055#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5056#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005057#if defined(REINTERPRET_INPUT_AS_3D)
5058 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5059#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005060 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005061#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005062 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
5063 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
5064 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
5065 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
5066 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
5067 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
5068 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
5069 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005070
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005071 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
5072 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
5073 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
5074 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
5075 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
5076 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
5077 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
5078 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005079#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5080#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005081#if defined(REINTERPRET_INPUT_AS_3D)
5082 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5083#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005084 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005085#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005086 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
5087 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
5088 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
5089 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
5090 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
5091 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
5092 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
5093 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005094
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005095 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
5096 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
5097 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
5098 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
5099 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
5100 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
5101 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
5102 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005103#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005104
5105 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005106 }
5107 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005108 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005109 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005110#if defined(REINTERPRET_INPUT_AS_3D)
5111 // Load values from matrix A
5112 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5113#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5114 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5115#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5116#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5117 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5118#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5119#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5120 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5121#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5122#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005123 // Load values from matrix A
5124 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5125#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5126 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5128#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5129 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5130#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5131#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5132 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5133#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005134#endif // defined(REINTERPRET_INPUT_AS_3D)
5135
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005136 // Load values from matrix B
5137 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005138 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005139
5140 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005141 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5142 acc0.s1 = fma(a0, b0.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005143#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005144 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5145 acc1.s1 = fma(a1, b0.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005146#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5147#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005148 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5149 acc2.s1 = fma(a2, b0.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005150#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5151#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005152 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5153 acc3.s1 = fma(a3, b0.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005154#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005155
5156 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005157 }
5158
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005159 int z = get_global_id(2);
5160
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005161 // Compute destination address
5162 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5163
Gian Marcoae2af742018-02-15 12:35:44 +00005164 // Compute dst address
5165 __global uchar *dst_addr = offset(&dst, 0, 0);
5166
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005167 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005168
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005169#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005170
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005171 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005172 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005173 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005174 // | |
5175 // | plane0 |
5176 // | |
5177 // |__________________|
5178 // |******************|
5179 // | cross_plane_pad |
5180 // |******************|
5181 // | |
5182 // | plane1 |
5183 // | |
5184 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00005185
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005186 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005187 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5188 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005189
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005190 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005191 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005192
5193 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5194 // multiply dst_stride_z by DEPTH_GEMM3D
5195 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005196#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005197 // Add offset for batched GEMM
5198 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005199#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5200
5201 // Multiply by the weight of matrix-matrix product and store the result
5202#if defined(ALPHA)
5203 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5204#endif // defined(ALPHA)
5205
5206 // Add beta*bias
5207#if defined(BETA)
5208 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5209
5210#if defined(BROADCAST_BIAS)
5211 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
5212
5213 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5214
5215#ifndef UNIT_BETA
5216 SCALE_BLOCK(1, float, bias, BETA);
5217#endif // UNIT_BIAS
5218
5219 // acc = acc + bias[broadcasted]
5220 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5221
5222#else // defined(BROADCAST_BIAS)
5223 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
5224 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5225
5226 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5227
5228#ifndef UNIT_BETA
5229 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
5230#endif // UNIT_BIAS
5231
5232 // acc = acc + bias
5233 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5234
5235#endif // defined(BROADCAST_BIAS)
5236#endif // defined(BETA)
5237
5238#if defined(ACTIVATION_TYPE)
5239 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
5240#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005241
5242 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005243 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005244#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005245 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005246#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5247#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005248 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005249#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5250#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005251 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005252#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005253}
5254
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005255#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005256/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
5257 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005258 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
5259 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5260 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5261 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5262 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005263 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5264 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005265 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005266 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5267 * The activation function is performed after the bias addition
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005268 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5269 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
5270 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5271 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5272 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5273 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5274 *
5275 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5276 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5277 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5278 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5279 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5280 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5281 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5282 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5283 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5284 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5285 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5286 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005287 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
5288 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5289 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5290 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5291 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5292 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005293 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5294 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5295 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5296 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5297 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5298 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5299 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5300 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005301 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005302 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5303 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5304 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
5305 */
5306__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
5307 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005308#if defined(BETA)
5309 IMAGE_DECLARATION(src2),
5310#endif // defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005311 IMAGE_DECLARATION(dst),
5312 uint src0_stride_z,
5313 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005314#if defined(BETA)
5315 uint src2_stride_z,
5316#endif //defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005317 uint dst_stride_z
5318#if defined(REINTERPRET_INPUT_AS_3D)
5319 ,
5320 uint src_cross_plane_pad
5321#endif // REINTERPRET_INPUT_AS_3D
5322#if defined(REINTERPRET_OUTPUT_AS_3D)
5323 ,
5324 uint dst_cross_plane_pad
5325#endif // REINTERPRET_OUTPUT_AS_3D
5326 )
5327{
5328 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5329
5330 // Compute starting address for matrix A and Matrix B
5331 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5332
5333 // Update address for the matrix A
5334 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5335
5336 // Update address for the matrix B
5337 src_addr.s1 += idx * sizeof(half);
5338
5339#if defined(REINTERPRET_INPUT_AS_3D)
5340 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5341 // in order to take into account the presence of possible cross plane paddings
5342 //
5343 // | |
5344 // | plane0 |
5345 // | |
5346 // |__________________|
5347 // |******************|
5348 // | cross_plane_pad |
5349 // |******************|
5350 // | |
5351 // | plane1 |
5352 // | |
5353 // |__________________|
5354
5355 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5356 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5357 zin = min(DEPTH_GEMM3D - 1, zin);
5358
5359 // Add offset due to the cross plane paddings
5360 zin *= (src_cross_plane_pad * src0_stride_y);
5361
5362 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5363 // multiply src0_stride_z by DEPTH_GEMM3D
5364 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5365
5366#else // defined(REINTERPRET_INPUT_AS_3D)
5367
5368 // Add offset for batched GEMM
5369 src_addr.s0 += get_global_id(2) * src0_stride_z;
5370
5371#endif // defined(REINTERPRET_INPUT_AS_3D)
5372
5373#if defined(MATRIX_B_DEPTH)
5374 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5375 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5376#else // defined(MATRIX_B_DEPTH)
5377 src_addr.s1 += get_global_id(2) * src1_stride_z;
5378#endif // defined(MATRIX_B_DEPTH)
5379
5380 float8 acc0 = 0.0h;
5381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5382 float8 acc1 = 0.0h;
5383#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5384#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5385 float8 acc2 = 0.0h;
5386#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5387#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5388 float8 acc3 = 0.0h;
5389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5390
5391 int i = 0;
5392 for(; i <= ((int)COLS_A - 4); i += 4)
5393 {
5394#if defined(REINTERPRET_INPUT_AS_3D)
5395 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005396 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5397#else // defined(REINTERPRET_INPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005398 // Load values from matrix A
5399 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5401 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5402#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5404 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5407 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5409#endif // defined(REINTERPRET_INPUT_AS_3D)
5410
5411 // Load values from matrix B
5412 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5413 src_addr.s1 += src1_stride_y;
5414
5415 // Accumulate
5416 acc0 = fma(b0, (float8)a0.s0, acc0);
5417#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5418 acc1 = fma(b0, (float8)a1.s0, acc1);
5419#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5420#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5421 acc2 = fma(b0, (float8)a2.s0, acc2);
5422#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5423#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5424 acc3 = fma(b0, (float8)a3.s0, acc3);
5425#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5426
5427 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5428 src_addr.s1 += src1_stride_y;
5429 acc0 = fma(b0, (float8)a0.s1, acc0);
5430#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5431 acc1 = fma(b0, (float8)a1.s1, acc1);
5432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5434 acc2 = fma(b0, (float8)a2.s1, acc2);
5435#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5437 acc3 = fma(b0, (float8)a3.s1, acc3);
5438#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5439
5440 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5441 src_addr.s1 += src1_stride_y;
5442 acc0 = fma(b0, (float8)a0.s2, acc0);
5443#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5444 acc1 = fma(b0, (float8)a1.s2, acc1);
5445#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5446#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5447 acc2 = fma(b0, (float8)a2.s2, acc2);
5448#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5449#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5450 acc3 = fma(b0, (float8)a3.s2, acc3);
5451#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5452
5453 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5454 src_addr.s1 += src1_stride_y;
5455 acc0 = fma(b0, (float8)a0.s3, acc0);
5456#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5457 acc1 = fma(b0, (float8)a1.s3, acc1);
5458#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5459#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5460 acc2 = fma(b0, (float8)a2.s3, acc2);
5461#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5462#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5463 acc3 = fma(b0, (float8)a3.s3, acc3);
5464#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5465
5466 src_addr.s0 += 4 * sizeof(half);
5467 }
5468
5469 for(; i < (int)COLS_A; ++i)
5470 {
5471#if defined(REINTERPRET_INPUT_AS_3D)
5472 // Load values from matrix A
5473 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5474#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5475 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5476#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5477#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5478 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5479#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5480#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5481 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5482#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5483#else // defined(REINTERPRET_INPUT_AS_3D)
5484 // Load values from matrix A
5485 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5486#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5487 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5488#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5489#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5490 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5491#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5493 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5495#endif // defined(REINTERPRET_INPUT_AS_3D)
5496
5497 // Load values from matrix B
5498 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5499
5500 src_addr += (int2)(sizeof(half), src1_stride_y);
5501
5502 // Accumulate
5503 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
5504#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5505 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
5506#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5507#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5508 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
5509#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5510#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5511 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
5512#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5513 }
5514
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005515 int z = get_global_id(2);
5516
5517 // Compute destination address
5518 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5519
5520 // Compute dst address
5521 __global uchar *dst_addr = offset(&dst, 0, 0);
5522
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005523 uint4 zout = 0;
5524
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005525#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005526
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005527 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
5528 // in order to take into account the presence of possible cross plane paddings
5529 //
5530 // | |
5531 // | plane0 |
5532 // | |
5533 // |__________________|
5534 // |******************|
5535 // | cross_plane_pad |
5536 // |******************|
5537 // | |
5538 // | plane1 |
5539 // | |
5540 // |__________________|
5541
5542 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005543 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5544 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005545
5546 // Add offset due to the cross plane paddings
5547 zout *= (dst_cross_plane_pad * dst_stride_y);
5548
5549 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5550 // multiply dst_stride_z by DEPTH_GEMM3D
5551 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005552#else // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005553 // Add offset for batched GEMM
5554 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005555#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005556
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005557 // Multiply by the weight of matrix-matrix product and store the result
5558#if defined(ALPHA)
5559 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5560#endif // defined(ALPHA)
5561
5562#if defined(BETA)
5563 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5564
5565#if defined(BROADCAST_BIAS)
5566 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
5567
5568 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5569
5570 float8 bias_f0 = convert_float8(bias0);
5571
5572#ifndef UNIT_BETA
5573 SCALE_BLOCK(1, float, bias_f, BETA);
5574#endif // UNIT_BIAS
5575
5576 // acc = acc + bias[broadcasted]
5577 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);
5578
5579#else // defined(BROADCAST_BIAS)
5580 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
5581 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5582
5583 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5584
5585 float8 bias_f0 = convert_float8(bias0);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005586#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005587 float8 bias_f1 = convert_float8(bias1);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005588#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5589#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005590 float8 bias_f2 = convert_float8(bias2);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005591#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5592#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005593 float8 bias_f3 = convert_float8(bias3);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005594#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005595
5596#ifndef UNIT_BETA
5597 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
5598#endif // UNIT_BIAS
5599
5600 // acc = acc + bias
5601 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);
5602
5603#endif // defined(BROADCAST_BIAS)
5604#endif // defined(BETA)
5605
5606 half8 acc_h0 = convert_half8(acc0);
5607#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5608 half8 acc_h1 = convert_half8(acc1);
5609#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5610#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5611 half8 acc_h2 = convert_half8(acc2);
5612#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5613#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5614 half8 acc_h3 = convert_half8(acc3);
5615#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5616
5617#if defined(ACTIVATION_TYPE)
5618 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
5619#endif // defined(ACTIVATION_TYPE)
5620
5621 // Store the output block
5622 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005623}
5624
5625/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5626 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005627 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5628 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5629 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5630 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5631 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005632 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5633 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005634 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005635 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5636 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005637 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5638 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005639 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5640 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5641 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5642 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5643 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005644 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5645 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5646 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5647 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5648 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5649 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5650 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5651 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5652 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5653 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5654 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5655 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005656 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5657 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5658 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5659 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5660 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5661 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005662 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5663 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5664 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5665 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5666 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5667 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005668 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5669 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005670 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005671 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005672 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5673 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005674 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Each work-item computes a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 8 columns of dst, kept in the
    // half8 accumulators acc0..acc3 declared below (hence NUM_ELEMS_PROCESSED_PER_THREAD_Y can be at most 4).
    // NOTE(review): the B loads (vload8), the bias address and the final store all use a fixed width of 8, while the
    // starting column of B below uses NUM_ELEMS_PROCESSED_PER_THREAD_X; presumably the host always sets it to 8 - confirm.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (src_addr.s0 is a byte offset into A, src_addr.s1 a byte offset into B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: move to the first row of this work-item's tile
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B: move to the first column of this work-item's tile
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D.
    // Lane i of zin belongs to row i of this work-item's tile.
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp the plane index to the last valid plane (DEPTH_GEMM3D - 1)
    zin = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings: a row lying in plane p has p paddings above it,
    // so its extra byte offset is p * src_cross_plane_pad * src0_stride_y
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row of the tile
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes 4 columns of A and the 4 matching rows of B
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance A by the 4 columns consumed (B was advanced row by row above)
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining columns of A when COLS_A is not a multiple of 4, one per iteration
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one column and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offsets handed to STORE_BLOCK (zout.s0..s3); they stay 0 unless the output
    // is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp the plane index to the last valid plane (DEPTH_GEMM3D - 1)
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings: plane index -> byte offset of the paddings above it
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // Per-row z offsets of 0 for LOAD_BLOCK: the bias is read without any cross-plane padding adjustment
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (the 8 values for this work-item's columns) is added to every row of the tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // UNIT_BETA presumably signals BETA == 1, in which case the scaling is skipped
#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // Full bias matrix: the tile is located with the same column/row/batch indices as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // UNIT_BETA presumably signals BETA == 1, in which case the scaling is skipped
#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005970#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005971
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005972#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005973
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005974#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005975/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
5976 *
Gian Marco19835e52018-01-30 13:35:54 +00005977 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005978 *
5979 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
5980 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
5981 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5982 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
5983 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005984 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
5985 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005986 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005987 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005988 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5989 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5990 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5991 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005992 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5993 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005994 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5995 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005996__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
5997 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005998{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005999 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006000 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
6001 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006002
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006003 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006004 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
6005
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006006 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006007 float4 c = vload4(0, (__global float *)src.ptr);
6008
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006009 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006010 float4 out = alpha_ab + (float4)BETA * c;
6011
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006012 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006013 vstore4(out, 0, (__global float *)dst.ptr);
6014}
6015
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006016#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006017/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
6018 *
Gian Marco19835e52018-01-30 13:35:54 +00006019 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006020 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006021 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
6022 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
6023 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6024 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
6025 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006026 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
6027 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006028 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006029 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006030 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6031 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6032 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6033 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006034 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
6035 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006036 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6037 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006038__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
6039 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006040{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006041 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006042 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
6043 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006044
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006045 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006046 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
6047
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006048 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006049 half8 c = vload8(0, (__global half *)src.ptr);
6050
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006051 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006052 half8 out = alpha_ab + (half8)BETA * c;
6053
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006054 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006055 vstore8(out, 0, (__global half *)dst.ptr);
6056}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006057#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006058#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006059
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006060#if defined(WIDTH_VECTOR_A)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006061/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
6062 *
Gian Marco19835e52018-01-30 13:35:54 +00006063 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006064 *
Gian Marco19835e52018-01-30 13:35:54 +00006065 * @note The input A and matrix B must not be reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006066 *
6067 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
6068 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
6069 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6070 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
6071 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6072 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006073 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006074 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
6075 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6076 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
6077 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6078 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
6079 * @param[in] src1_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
6080 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006081 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006082 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6083 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6084 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6085 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
6086 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6087 */
6088__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
6089 TENSOR3D_DECLARATION(src1),
6090 IMAGE_DECLARATION(dst))
6091{
6092 int idx = get_global_id(0) * 4;
6093 int idy = get_global_id(1);
6094
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006095 // Compute the address for the vector A and matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006096 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
6097 src_addr.s1 += idx * sizeof(float);
6098
6099 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
6100
6101 float4 acc = 0.0f;
6102
Georgios Pinitas96880cf2017-10-20 18:52:20 +01006103 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006104 {
6105 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
6106 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
6107 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
6108
6109 acc += b0 * (float4)a0.s0;
6110 acc += b1 * (float4)a0.s1;
6111 }
6112
6113 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
6114 {
6115 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
6116 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
6117
6118 acc += b0 * (float4)a0;
6119 }
6120
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006121 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006122 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6123
6124 vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
6125}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006126#endif // defined(WIDTH_VECTOR_A)
6127
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item adds VECTOR_SIZE bias elements to the VECTOR_SIZE accumulator elements at its
 * position, writing the sum back in place. The biases tensor is 1D, so every row of the
 * accumulator reuses the same bias values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_elems = (__global DATA_TYPE *)accum.ptr;

    // Load VECTOR_SIZE accumulator elements, then add the matching VECTOR_SIZE bias elements
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, accum_elems);
    accum_value += VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    // Write the biased values back into the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, accum_elems);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)