blob: 41e5c338b309df5c86054fdb3457bb22757e0864 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
/* Lane-index vectors: INCn expands to the uint vector (0, 1, ..., n-1).
 * One definition exists for each vector width used for K0 (2, 3, 4, 8, 16).
 * Every expansion is wrapped in parentheses so that it behaves as a single
 * operand when it is embedded in a larger expression.
 */
#define INC2 ((VEC_DATA_TYPE(uint, 2))(0, 1))
#define INC3 ((VEC_DATA_TYPE(uint, 3))(0, 1, 2))
#define INC4 ((VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3))
#define INC8 ((VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7))
#define INC16 ((VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))

/* INC(n) selects INC<n>. The two-level definition lets a macro argument such as K0
 * be expanded to its numeric value before it is pasted onto the INC prefix.
 */
#define CONCAT_INC(n) INC##n
#define INC(n) CONCAT_INC(n)
35
/* BOUNDARY_CONDITION_X(x, a)
 *
 * x: index of the K0-wide block along X handled by the work-item
 * a: K0-wide vector holding one loaded row of the block (updated in place)
 *
 * When SRC_WIDTH is not a multiple of K0, every lane of a whose column index
 * (x * K0 + lane) is not smaller than SRC_WIDTH is replaced with 0; the remaining
 * lanes are kept. When SRC_WIDTH is a multiple of K0 the macro expands to an empty
 * statement expression.
 *
 * NOTE(review): the OpenCL C select() built-in takes an integer-typed mask with the same
 * element width as the data operands, while the mask here is converted to
 * VEC_DATA_TYPE(DATA_TYPE, K0). Confirm that DATA_TYPE is always passed as an integer
 * type of the element's width.
 */
#if(SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({ \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note If SRC_WIDTH is not a multiple of K0, the elements of a block whose column index is not smaller than SRC_WIDTH are set to zero before the block is stored.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the first elements of two consecutive blocks stored on the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
    // Both expansions are fully parenthesized so that they can be used safely inside any expression
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: the block starts at column (x * K0) and row (y * M0) of the source
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Output address: V0 consecutive blocks along Y share the same output row (y / V0).
    // (y % V0) selects the position of the block inside that group
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (a0, a1, ... a(M0-1)) from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Clear the lanes that fall beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Create variables: uint zout0=0, zout1=0, ... zout15=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
/* TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
 *
 * output_ptr:    byte pointer to the start of the destination block
 * output_step_x: distance, in elements, between two consecutive transposed columns in the output
 * i:             single hexadecimal digit (0-9, A-F) selecting the column; it is pasted both onto
 *                the vector component selector (.s<i>) and onto the 0x prefix to form the column index
 *
 * Gathers component i of the row vectors a0 ... a(M0-1), which must be in scope at the point of use,
 * and stores the resulting M0 values contiguously at output_ptr + i * output_step_x elements.
 * OpenCL C only provides vector widths 2, 3, 4, 8 and 16, so for M0 = 5, 6 and 7 the column is
 * written as a 4-wide store followed by a store of the remaining 1, 2 or 3 values.
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        DATA_TYPE res1 = a4.s##i; \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 2) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(2) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 3) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(3) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
245
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note If SRC_WIDTH is not a multiple of K0, the elements of a block whose column index is not smaller than SRC_WIDTH are set to zero before the block is stored.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the first elements of two consecutive blocks stored on the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive transposed columns of the same block in the output
    // Both expansions are fully parenthesized so that they can be used safely inside any expression
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: the block starts at column (x * K0) and row (y * M0) of the source
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Output address: V0 consecutive blocks along Y share the same output row (y / V0).
    // (y % V0) selects the position of the block inside that group
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (a0, a1, ... a(M0-1)) from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Clear the lanes that fall beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per column of the block; the last argument is the column index as a hexadecimal digit
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 * @note The first row of a block is always loaded. Each following row is loaded only if its index is smaller than SRC_HEIGHT; otherwise it is stored with its initial value of zero.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: distance (in elements) between the first elements of two consecutive blocks stored on the same output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
    // Both expansions are fully parenthesized so that they can be used safely inside any expression
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: the block starts at column (x * N0) and row (y * K0) of batch z
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Output address: H0 consecutive blocks along X share the same output row (x / H0).
    // (x % H0) selects the position of the block inside that group
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
                                     x / (uint)H0)
                                 * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. Row 0 is loaded unconditionally; every other row is loaded
    // only when it lies inside the source tensor, so rows past SRC_HEIGHT keep their zero value
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    // Create variables: uint zout0=0, zout1=0, ... zout15=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
553
554#if defined(TRANSPOSE)
555/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
556 * the output matrix unrolling the values.
557 *
558 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
559 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
560 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
561 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
564 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000565 * N0: 2,3,4,8,16
566 * K0: 2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000567 * H0: greater than 0
568 *
569 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
570 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
571 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
573 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
575 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
576 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
577 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
578 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
579 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
580 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
581 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
582 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
583 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
584 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
585 */
586__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
587 TENSOR3D_DECLARATION(dst))
588{
589 // Block size
590#define BLOCK_SIZE ((K0) * (N0))
591
592 // Output offset X
593#if defined(INTERLEAVE)
594#define OUTPUT_OFFSET_X (K0)
595#else // defined(INTERLEAVE)
596#define OUTPUT_OFFSET_X (BLOCK_SIZE)
597#endif // defined(INTERLEAVE)
598
599 // Output step X
600#if defined(INTERLEAVE)
601#define OUTPUT_STEP_X (K0) * (H0)
602#else // Do not interleave
603#define OUTPUT_STEP_X (K0)
604#endif // defined(INTERLEAVE)
605
606 // Compute source and destination addresses
607 uint x = get_global_id(0);
608 uint y = get_global_id(1);
609 uint z = get_global_id(2);
610
611 // ------------------ Compute input/output addresses ---------------------------
612
613 // Compute the input address
614 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
615
616 // Compute the output address
617 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
618 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
619
620 // ---------------------------Load input values --------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000621 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000622
623 // Load values from the RHS matrix
624 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
625 if(y * (uint)K0 + 1 < SRC_HEIGHT)
626 {
627 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
628 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000629#if K0 > 2
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000630 if(y * (uint)K0 + 2 < SRC_HEIGHT)
631 {
632 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
633 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000634#endif // K0 > 2
635#if K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000636 if(y * (uint)K0 + 3 < SRC_HEIGHT)
637 {
638 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
639 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000640#endif // K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000641#if K0 > 4
642 if(y * (uint)K0 + 4 < SRC_HEIGHT)
643 {
644 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
645 }
646 if(y * (uint)K0 + 5 < SRC_HEIGHT)
647 {
648 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
649 }
650 if(y * (uint)K0 + 6 < SRC_HEIGHT)
651 {
652 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
653 }
654 if(y * (uint)K0 + 7 < SRC_HEIGHT)
655 {
656 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
657 }
658#endif // K0 > 4
659#if K0 > 8
Gian Marco Iodice89124342018-12-19 14:17:22 +0000660 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000661 {
662 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
663 }
664 if(y * (uint)K0 + 9 < SRC_HEIGHT)
665 {
666 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
667 }
668 if(y * (uint)K0 + 10 < SRC_HEIGHT)
669 {
670 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
671 }
672 if(y * (uint)K0 + 11 < SRC_HEIGHT)
673 {
674 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
675 }
676 if(y * (uint)K0 + 12 < SRC_HEIGHT)
677 {
678 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
679 }
680 if(y * (uint)K0 + 13 < SRC_HEIGHT)
681 {
682 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
683 }
684 if(y * (uint)K0 + 14 < SRC_HEIGHT)
685 {
686 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
687 }
688 if(y * (uint)K0 + 15 < SRC_HEIGHT)
689 {
690 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
691 }
692#endif // K0 > 8
693
694 // ---------------------------Transpose the block ------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000695 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000696
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000697#if K0 == 2
698 // This part computes the following transpositions:
699 // 2x2 -> 2x2
700 // 2x4 -> 4x2
701 // 2x8 -> 8x2
702 // 2x16 -> 16x2
703 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
704 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
705#if N0 > 2
706 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
707#endif // N0 > 2
708#if N0 > 3
709 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
710#endif // N0 > 3
711#if N0 > 4
712 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
713 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
714 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
715 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
716#endif // N0 > 4
717#if N0 > 8
718 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
719 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
720 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
721 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
722 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
723 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
724 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
725 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
726#endif // N0 > 8
727
728#elif K0 == 3 // K0 == 2
729 // This part computes the following transpositions:
730 // 3x2 -> 2x3
731 // 3x4 -> 4x3
732 // 3x8 -> 8x3
733 // 3x16 -> 16x3
734 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
735 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
736#if N0 > 2
737 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
738#endif // N0 > 2
739#if N0 > 3
740 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
741#endif // N0 > 3
742#if N0 > 4
743 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
744 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
745 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
746 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
747#endif // N0 > 4
748#if N0 > 8
749 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
750 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
751 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
752 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
753 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
754 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
755 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
756 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
757#endif // N0 > 8
758
759#elif K0 == 4 // K0 == 4
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000760 // This part computes the following transpositions:
761 // 4x2 -> 2x4
762 // 4x4 -> 4x4
763 // 4x8 -> 8x4
764 // 4x16 -> 16x4
765 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
766 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
767#if N0 > 2
768 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000769#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000770#if N0 > 3
771 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
772#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000773#if N0 > 4
774 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
775 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
776 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
777 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
778#endif // N0 > 4
779#if N0 > 8
780 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
781 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
782 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
783 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
784 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
785 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
786 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
787 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
788#endif // N0 > 8
789
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000790#elif K0 == 8 // K0 == 8
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000791 // This part computes the following transpositions:
792 // 8x2 -> 2x8
793 // 8x4 -> 4x8
794 // 8x8 -> 8x8
795 // 8x16 -> 16x8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000796 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
797 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000798#if N0 > 2
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000799 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000800#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000801#if N0 > 3
802 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
803#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000804#if N0 > 4
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000805 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
806 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
807 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
808 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000809#endif // N0 > 4
810#if N0 > 8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000811 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
812 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
813 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
814 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
815 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
816 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
817 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
818 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000819#endif // N0 > 8
820
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000821#elif K0 == 16 // K0 == 16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000822
823 // This part computes the following transpositions:
824 // 16x2 -> 2x16
825 // 16x4 -> 4x16
826 // 16x8 -> 8x16
827 // 16x16 -> 16x16
828 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
829 a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
830 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
831 a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
832#if N0 > 2
833 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
834 a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000835#endif // N0 > 2
836#if N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000837 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
838 a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000839#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000840#if N0 > 4
841 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
842 a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
843 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
844 a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
845 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
846 a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
847 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
848 a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
849#endif // N0 > 4
850#if N0 > 8
851 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
852 a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
853 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
854 a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
855 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
856 a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
857 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
858 a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
859 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
860 a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
861 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
862 a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
863 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
864 a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
865 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
866 a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
867#endif // N0 > 8
868
869#else // N0 == 16
870#error "Not supported N0 value"
871#endif // N0 > 2
872
873 // ---------------------------Store the output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100874 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
875 STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000876
877#undef BLOCK_SIZE
878#undef OUTPUT_OFFSET_X
879#undef OUTPUT_STEP_X
880}
881#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
/** Paste two tokens together. The operands of ## are pasted as written (they are not macro-expanded first). */
#define CONCAT(a, b) a##b

/** ARM_DOT<k>(a, b, c): accumulate the k-element dot product of a and b into c, i.e. c += a . b,
 *  one fma per element.
 *
 * a and b are parenthesized as a whole before any component (.sN, .lo, .hi) is selected, so that
 * arguments which are not plain identifiers (e.g. a dereferenced pointer) bind correctly.
 * c must be a modifiable lvalue; it is evaluated once per fma.
 */
#define ARM_DOT1(a, b, c)     \
    ({                        \
        c = fma((a), (b), c); \
    })
#define ARM_DOT2(a, b, c)           \
    ({                              \
        c = fma((a).s0, (b).s0, c); \
        c = fma((a).s1, (b).s1, c); \
    })
#define ARM_DOT3(a, b, c)           \
    ({                              \
        ARM_DOT2(a, b, c);          \
        c = fma((a).s2, (b).s2, c); \
    })
#define ARM_DOT4(a, b, c)           \
    ({                              \
        ARM_DOT3(a, b, c);          \
        c = fma((a).s3, (b).s3, c); \
    })
#define ARM_DOT8(a, b, c)                 \
    ({                                    \
        ARM_DOT4(((a).lo), ((b).lo), c);  \
        ARM_DOT4(((a).hi), ((b).hi), c);  \
    })
#define ARM_DOT16(a, b, c)                \
    ({                                    \
        ARM_DOT8(((a).lo), ((b).lo), c);  \
        ARM_DOT8(((a).hi), ((b).hi), c);  \
    })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1013 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1018 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
1025 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1026 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1027 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1028 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1029 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1030 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1031 *
1032 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1033 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1034 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1035 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1036 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1037 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1038 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1039 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1040 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1041 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1042 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1043 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1044 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1045 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1046 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1047 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1048 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1049 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1050 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1051 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1052 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1053 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1054 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1055 */
1056__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1057 IMAGE_DECLARATION(rhs),
1058 IMAGE_DECLARATION(dst),
1059 uint lhs_stride_z,
1060 uint rhs_stride_z,
1061 uint dst_stride_z
1062#if defined(REINTERPRET_INPUT_AS_3D)
1063 ,
1064 uint lhs_cross_plane_pad
1065#endif // REINTERPRET_INPUT_AS_3D
1066#if defined(REINTERPRET_OUTPUT_AS_3D)
1067 ,
1068 uint dst_cross_plane_pad
1069#endif // REINTERPRET_OUTPUT_AS_3D
1070 )
1071{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001072 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001073#define RHS_BLOCK_SIZE ((K0) * (N0))
1074
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001075 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076#if defined(RHS_INTERLEAVE)
1077#define RHS_OFFSET_X (K0)
1078#define RHS_STEP_X ((K0) * (H0))
1079#define RHS_STEP_LOOP (1)
1080#else // defined(RHS_INTERLEAVE)
1081#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1082#define RHS_STEP_X (K0)
1083#define RHS_STEP_LOOP (H0)
1084#endif // defined(RHS_INTERLEAVE)
1085
1086 uint x = get_global_id(0);
1087 uint y = get_global_id(1);
1088 uint z = get_global_id(2);
1089
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001090#if defined(DUMMY_WORK_ITEMS)
1091 if((x * N0 >= N) || (y * M0 >= M))
1092 {
1093 return;
1094 }
1095#endif // defined(DUMMY_WORK_ITEMS)
1096
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001097 // Compute LHS matrix address
1098 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1099
1100 // Compute RHS matrix address
1101 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1102
1103#if defined(MATRIX_B_DEPTH)
1104 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1105 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1106#else // defined(MATRIX_B_DEPTH)
1107 rhs_offset += z * rhs_stride_z;
1108#endif // defined(MATRIX_B_DEPTH)
1109
Usama Arif0681e3b2019-04-25 14:28:07 +01001110 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1111 REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112
1113#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001114 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1115 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001116
1117 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1118 // multiply lhs_stride_z by DEPTH_GEMM3D
1119 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1120
1121#else // defined(REINTERPRET_INPUT_AS_3D)
1122
1123 // Add offset for batched GEMM
1124 lhs_offset += z * lhs_stride_z;
1125
1126#endif // defined(REINTERPRET_INPUT_AS_3D)
1127
1128 // Initialize the accumulators
1129 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1130
1131 int i = 0;
1132 for(; i <= (K - K0); i += K0)
1133 {
1134 // Supported cases (M0, K0):
1135 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1136 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1137 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1138 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1139 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1140 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1141 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1142 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1143 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001144 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001145
1146 // Load values from RHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001147 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001148
1149 // Accumulate
1150 ARM_DOT_K0XN0(K0, a0, b, c0);
1151#if M0 > 1
1152 ARM_DOT_K0XN0(K0, a1, b, c1);
1153#endif // M0 > 1
1154#if M0 > 2
1155 ARM_DOT_K0XN0(K0, a2, b, c2);
1156#endif // M0 > 2
1157#if M0 > 3
1158 ARM_DOT_K0XN0(K0, a3, b, c3);
1159#endif // M0 > 3
1160#if M0 > 4
1161 ARM_DOT_K0XN0(K0, a4, b, c4);
1162#endif // M0 > 4
1163#if M0 > 5
1164 ARM_DOT_K0XN0(K0, a5, b, c5);
1165#endif // M0 > 5
1166#if M0 > 6
1167 ARM_DOT_K0XN0(K0, a6, b, c6);
1168#endif // M0 > 6
1169#if M0 > 7
1170 ARM_DOT_K0XN0(K0, a7, b, c7);
1171#endif // M0 > 7
1172
1173 lhs_offset += K0 * sizeof(DATA_TYPE);
1174 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1175 }
1176
1177 // Left-over accumulations
1178 for(; i < K; ++i)
1179 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001180 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001181 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001182
1183 // Load values from RHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001184 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001185
1186 // Accumulate
1187 ARM_DOT_K0XN0(1, a0, b, c0);
1188#if M0 > 1
1189 ARM_DOT_K0XN0(1, a1, b, c1);
1190#endif // M0 > 1
1191#if M0 > 2
1192 ARM_DOT_K0XN0(1, a2, b, c2);
1193#endif // M0 > 2
1194#if M0 > 3
1195 ARM_DOT_K0XN0(1, a3, b, c3);
1196#endif // M0 > 3
1197#if M0 > 4
1198 ARM_DOT_K0XN0(1, a4, b, c4);
1199#endif // M0 > 4
1200#if M0 > 5
1201 ARM_DOT_K0XN0(1, a5, b, c5);
1202#endif // M0 > 5
1203#if M0 > 6
1204 ARM_DOT_K0XN0(1, a6, b, c6);
1205#endif // M0 > 6
1206#if M0 > 7
1207 ARM_DOT_K0XN0(1, a7, b, c7);
1208#endif // M0 > 7
1209
1210 lhs_offset += sizeof(DATA_TYPE);
1211 rhs_offset += sizeof(DATA_TYPE);
1212 }
1213
1214 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1215
1216 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1217
1218#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001219
1220 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001221 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001222
1223 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1224 // multiply dst_stride_z by DEPTH_GEMM3D
1225 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1226
1227#else // defined(REINTERPRET_OUTPUT_AS_3D)
1228
1229 // Add offset for batched GEMM
1230 dst_addr += z * dst_stride_z;
1231
1232#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1233
1234 // Multiply by the weight of matrix-matrix product and store the result
1235#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237#endif // defined(ALPHA)
1238
1239 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001240 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001241
1242#undef RHS_BLOCK_SIZE
1243#undef RHS_OFFSET_X
1244#undef RHS_STEP_X
1245}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001246
/** Fused multiply-add accumulated in place: c = fma(a, b, c). c must be a modifiable lvalue. */
#define VFMA(a, b, c) ({ c = fma(a, b, c); })
1251
1252#if M0 == 1
1253#define LD_RHS_VFMA_M0xN0(i, a, c) \
1254 ({ \
1255 VEC_DATA_TYPE(DATA_TYPE, N0) \
1256 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1257 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1258 })
1259#elif M0 == 2 // M0 == 2
1260#define LD_RHS_VFMA_M0xN0(i, a, c) \
1261 ({ \
1262 VEC_DATA_TYPE(DATA_TYPE, N0) \
1263 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1264 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1265 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1266 })
1267#elif M0 == 3 // M0 == 3
1268#define LD_RHS_VFMA_M0xN0(i, a, c) \
1269 ({ \
1270 VEC_DATA_TYPE(DATA_TYPE, N0) \
1271 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1272 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1273 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1274 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1275 })
1276#elif M0 == 4 // M0 == 4
1277#define LD_RHS_VFMA_M0xN0(i, a, c) \
1278 ({ \
1279 VEC_DATA_TYPE(DATA_TYPE, N0) \
1280 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1281 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1282 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1283 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1284 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1285 })
1286#elif M0 == 5 // M0 == 5
1287#define LD_RHS_VFMA_M0xN0(i, a, c) \
1288 ({ \
1289 VEC_DATA_TYPE(DATA_TYPE, N0) \
1290 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1291 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1292 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1293 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1294 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1295 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1296 })
1297#elif M0 == 6 // M0 == 6
1298#define LD_RHS_VFMA_M0xN0(i, a, c) \
1299 ({ \
1300 VEC_DATA_TYPE(DATA_TYPE, N0) \
1301 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1302 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1303 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1304 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1305 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1306 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1307 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1308 })
1309#elif M0 == 7 // M0 == 7
1310#define LD_RHS_VFMA_M0xN0(i, a, c) \
1311 ({ \
1312 VEC_DATA_TYPE(DATA_TYPE, N0) \
1313 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1314 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1315 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1316 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1317 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1318 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1319 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1320 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1321 })
1322#elif M0 == 8 // M0 == 8
1323#define LD_RHS_VFMA_M0xN0(i, a, c) \
1324 ({ \
1325 VEC_DATA_TYPE(DATA_TYPE, N0) \
1326 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1327 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1328 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1329 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1330 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1331 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1332 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1333 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1334 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1335 })
1336#else // M0 not supported
1337#error "M0 not supported"
1338#endif // M0 not supported
1339
1340/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1341 * The LHS matrix is NOT reshaped
1342 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1343 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001344 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1345 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90).
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001346 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1347 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1348 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1349 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1350 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1351 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1352 * - N0 = 2, 3, 4, 8, 16
1353 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001354 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001355 *
1356 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1357 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1358 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1359 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1360 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1361 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1362 *
1363 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1364 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1365 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1366 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1367 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1368 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1369 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1370 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1371 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1372 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1373 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1374 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1375 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1376 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1377 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1378 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1379 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1380 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1381 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1382 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1383 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1384 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1385 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1386 */
1387__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1388 IMAGE_DECLARATION(rhs),
1389 IMAGE_DECLARATION(dst),
1390 uint lhs_stride_z,
1391 uint rhs_stride_z,
1392 uint dst_stride_z
1393#if defined(REINTERPRET_INPUT_AS_3D)
1394 ,
1395 uint lhs_cross_plane_pad
1396#endif // REINTERPRET_INPUT_AS_3D
1397#if defined(REINTERPRET_OUTPUT_AS_3D)
1398 ,
1399 uint dst_cross_plane_pad
1400#endif // REINTERPRET_OUTPUT_AS_3D
1401 )
1402{
1403 // Block size
1404#define RHS_BLOCK_SIZE ((K0) * (N0))
1405
1406 // RHS offset and step X
1407#if defined(RHS_INTERLEAVE)
1408#define RHS_OFFSET_X (N0)
1409#define RHS_STEP_X ((N0) * (H0))
1410#define RHS_STEP_LOOP (1)
1411#else // defined(RHS_INTERLEAVE)
1412#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1413#define RHS_STEP_X (N0)
1414#define RHS_STEP_LOOP (H0)
1415#endif // defined(RHS_INTERLEAVE)
1416
1417 uint x = get_global_id(0);
1418 uint y = get_global_id(1);
1419 uint z = get_global_id(2);
1420
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001421#if defined(DUMMY_WORK_ITEMS)
1422 if((x * N0 >= N) || (y * M0 >= M))
1423 {
1424 return;
1425 }
1426#endif // defined(DUMMY_WORK_ITEMS)
1427
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001428 // Compute LHS matrix address
1429 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1430
1431 // Compute RHS matrix address
1432 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1433
1434#if defined(MATRIX_B_DEPTH)
1435 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1436 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1437#else // defined(MATRIX_B_DEPTH)
1438 rhs_offset += z * rhs_stride_z;
1439#endif // defined(MATRIX_B_DEPTH)
1440
1441 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1442
1443#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001444
1445 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001446 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001447
1448 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1449 // multiply lhs_stride_z by DEPTH_GEMM3D
1450 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1451
1452#else // defined(REINTERPRET_INPUT_AS_3D)
1453
1454 // Add offset for batched GEMM
1455 lhs_offset += z * lhs_stride_z;
1456
1457#endif // defined(REINTERPRET_INPUT_AS_3D)
1458
1459 // Initialize the accumulators
1460 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1461
1462 int i = 0;
1463 for(; i <= (K - K0); i += K0)
1464 {
1465 // Supported cases (M0, K0):
1466 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1467 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1468 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1469 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1470 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1471 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1472 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1473 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1474 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001475 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001476
1477 LD_RHS_VFMA_M0xN0(0, a, c);
1478 LD_RHS_VFMA_M0xN0(1, a, c);
1479#if K0 > 2
1480 LD_RHS_VFMA_M0xN0(2, a, c);
1481#endif // K0 > 2
1482#if K0 > 3
1483 LD_RHS_VFMA_M0xN0(3, a, c);
1484#endif // K0 > 3
1485#if K0 > 4
1486 LD_RHS_VFMA_M0xN0(4, a, c);
1487 LD_RHS_VFMA_M0xN0(5, a, c);
1488 LD_RHS_VFMA_M0xN0(6, a, c);
1489 LD_RHS_VFMA_M0xN0(7, a, c);
1490#endif // K0 > 4
1491#if K0 > 8
1492 LD_RHS_VFMA_M0xN0(8, a, c);
1493 LD_RHS_VFMA_M0xN0(9, a, c);
1494 LD_RHS_VFMA_M0xN0(A, a, c);
1495 LD_RHS_VFMA_M0xN0(B, a, c);
1496 LD_RHS_VFMA_M0xN0(C, a, c);
1497 LD_RHS_VFMA_M0xN0(D, a, c);
1498 LD_RHS_VFMA_M0xN0(E, a, c);
1499 LD_RHS_VFMA_M0xN0(F, a, c);
1500#endif // K0 > 8
1501
1502 lhs_offset += K0 * sizeof(DATA_TYPE);
1503 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1504 }
1505
1506 // Left-over accumulations
1507 for(; i < K; ++i)
1508 {
1509 // Load values from LHS matrix
1510 VEC_DATA_TYPE(DATA_TYPE, 2)
1511 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1512#if M0 > 1
1513 VEC_DATA_TYPE(DATA_TYPE, 2)
1514 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1515#endif // M0 > 1
1516#if M0 > 2
1517 VEC_DATA_TYPE(DATA_TYPE, 2)
1518 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1519#endif // M0 > 2
1520#if M0 > 3
1521 VEC_DATA_TYPE(DATA_TYPE, 2)
1522 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1523#endif // M0 > 3
1524#if M0 > 4
1525 VEC_DATA_TYPE(DATA_TYPE, 2)
1526 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1527#endif // M0 > 4
1528#if M0 > 5
1529 VEC_DATA_TYPE(DATA_TYPE, 2)
1530 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1531#endif // M0 > 5
1532#if M0 > 6
1533 VEC_DATA_TYPE(DATA_TYPE, 2)
1534 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1535#endif // M0 > 6
1536#if M0 > 7
1537 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001538 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001539#endif // M0 > 7
1540
1541 LD_RHS_VFMA_M0xN0(0, a, c);
1542
1543 lhs_offset += sizeof(DATA_TYPE);
1544 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1545 }
1546
1547 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1548
1549 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1550
1551#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001552 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001553 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001554
1555 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1556 // multiply dst_stride_z by DEPTH_GEMM3D
1557 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1558
1559#else // defined(REINTERPRET_OUTPUT_AS_3D)
1560
1561 // Add offset for batched GEMM
1562 dst_addr += z * dst_stride_z;
1563
1564#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1565
1566 // Multiply by the weight of matrix-matrix product and store the result
1567#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001568 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001569#endif // defined(ALPHA)
1570
1571 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001572 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001573
1574#undef RHS_BLOCK_SIZE
1575#undef RHS_OFFSET_X
1576#undef RHS_STEP_X
1577}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001578#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001579
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001580#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001581
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001582#if K0 == 2
1583#define ARM_DOT_K0(a, b, c) \
1584 ({ \
1585 c = fma(a.s0, b.s0, c); \
1586 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001587 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001588#elif K0 == 3 // K0 == 3
1589#define ARM_DOT_K0(a, b, c) \
1590 ({ \
1591 c = fma(a.s0, b.s0, c); \
1592 c = fma(a.s1, b.s1, c); \
1593 c = fma(a.s2, b.s2, c); \
1594 })
1595#elif K0 == 4 // K0 == 4
1596#define ARM_DOT_K0(a, b, c) \
1597 ({ \
1598 c = fma(a.s0, b.s0, c); \
1599 c = fma(a.s1, b.s1, c); \
1600 c = fma(a.s2, b.s2, c); \
1601 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001602 })
1603#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001604#define ARM_DOT_K0(a, b, c) \
1605 ({ \
1606 c = fma(a.s0, b.s0, c); \
1607 c = fma(a.s1, b.s1, c); \
1608 c = fma(a.s2, b.s2, c); \
1609 c = fma(a.s3, b.s3, c); \
1610 c = fma(a.s4, b.s4, c); \
1611 c = fma(a.s5, b.s5, c); \
1612 c = fma(a.s6, b.s6, c); \
1613 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001614 })
1615#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001616#define ARM_DOT_K0(a, b, c) \
1617 ({ \
1618 c = fma(a.s0, b.s0, c); \
1619 c = fma(a.s1, b.s1, c); \
1620 c = fma(a.s2, b.s2, c); \
1621 c = fma(a.s3, b.s3, c); \
1622 c = fma(a.s4, b.s4, c); \
1623 c = fma(a.s5, b.s5, c); \
1624 c = fma(a.s6, b.s6, c); \
1625 c = fma(a.s7, b.s7, c); \
1626 c = fma(a.s8, b.s8, c); \
1627 c = fma(a.s9, b.s9, c); \
1628 c = fma(a.sA, b.sA, c); \
1629 c = fma(a.sB, b.sB, c); \
1630 c = fma(a.sC, b.sC, c); \
1631 c = fma(a.sD, b.sD, c); \
1632 c = fma(a.sE, b.sE, c); \
1633 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001634 })
1635#else // K0 not supported
1636#error "K0 value not supported"
1637#endif // K0 conditions
1638
1639#if N0 == 2
1640#define ARM_DOT_K0XN0(a, b, c) \
1641 ({ \
1642 ARM_DOT_K0((a), (b##0), (c.s0)); \
1643 ARM_DOT_K0((a), (b##1), (c.s1)); \
1644 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001645#elif N0 == 3 // N0 == 3
1646#define ARM_DOT_K0XN0(a, b, c) \
1647 ({ \
1648 ARM_DOT_K0((a), (b##0), (c.s0)); \
1649 ARM_DOT_K0((a), (b##1), (c.s1)); \
1650 ARM_DOT_K0((a), (b##2), (c.s2)); \
1651 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001652#elif N0 == 4 // N0 == 4
1653#define ARM_DOT_K0XN0(a, b, c) \
1654 ({ \
1655 ARM_DOT_K0((a), (b##0), (c.s0)); \
1656 ARM_DOT_K0((a), (b##1), (c.s1)); \
1657 ARM_DOT_K0((a), (b##2), (c.s2)); \
1658 ARM_DOT_K0((a), (b##3), (c.s3)); \
1659 })
1660#elif N0 == 8 // N0 == 8
1661#define ARM_DOT_K0XN0(a, b, c) \
1662 ({ \
1663 ARM_DOT_K0((a), (b##0), (c.s0)); \
1664 ARM_DOT_K0((a), (b##1), (c.s1)); \
1665 ARM_DOT_K0((a), (b##2), (c.s2)); \
1666 ARM_DOT_K0((a), (b##3), (c.s3)); \
1667 ARM_DOT_K0((a), (b##4), (c.s4)); \
1668 ARM_DOT_K0((a), (b##5), (c.s5)); \
1669 ARM_DOT_K0((a), (b##6), (c.s6)); \
1670 ARM_DOT_K0((a), (b##7), (c.s7)); \
1671 })
1672#elif N0 == 16 // N0 == 16
1673#define ARM_DOT_K0XN0(a, b, c) \
1674 ({ \
1675 ARM_DOT_K0((a), (b##0), (c.s0)); \
1676 ARM_DOT_K0((a), (b##1), (c.s1)); \
1677 ARM_DOT_K0((a), (b##2), (c.s2)); \
1678 ARM_DOT_K0((a), (b##3), (c.s3)); \
1679 ARM_DOT_K0((a), (b##4), (c.s4)); \
1680 ARM_DOT_K0((a), (b##5), (c.s5)); \
1681 ARM_DOT_K0((a), (b##6), (c.s6)); \
1682 ARM_DOT_K0((a), (b##7), (c.s7)); \
1683 ARM_DOT_K0((a), (b##8), (c.s8)); \
1684 ARM_DOT_K0((a), (b##9), (c.s9)); \
1685 ARM_DOT_K0((a), (b##A), (c.sA)); \
1686 ARM_DOT_K0((a), (b##B), (c.sB)); \
1687 ARM_DOT_K0((a), (b##C), (c.sC)); \
1688 ARM_DOT_K0((a), (b##D), (c.sD)); \
1689 ARM_DOT_K0((a), (b##E), (c.sE)); \
1690 ARM_DOT_K0((a), (b##F), (c.sF)); \
1691 })
1692#else // N0 not supported
1693#error "N0 value not supported"
1694#endif // N0 conditions
1695
1696/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1697 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1698 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1699 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001700 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1701 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001702 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1703 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1704 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1705 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1706 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1707 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001708 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001709 * - N0 = 2, 3, 4, 8, 16
1710 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001711 * - V0 >= 1
1712 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001713 *
1714 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1715 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1716 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1717 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1718 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1719 *
1720 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1721 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1722 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1723 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1724 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1725 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001726 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001727 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1728 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1729 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1730 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1731 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001732 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001733 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1734 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1735 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1736 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1737 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001738 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001739 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1740 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1741 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1742 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1743 */
1744__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1745 IMAGE_DECLARATION(rhs),
1746 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001747 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001748 uint lhs_stride_z,
1749 uint rhs_stride_z,
1750 uint dst_stride_z
1751#if defined(REINTERPRET_OUTPUT_AS_3D)
1752 ,
1753 uint dst_cross_plane_pad
1754#endif // REINTERPRET_OUTPUT_AS_3D
1755 )
1756{
1757 // Block size
1758#define LHS_BLOCK_SIZE ((K0) * (M0))
1759
1760#if defined(LHS_INTERLEAVE)
1761#define LHS_OFFSET_X (K0)
1762#define LHS_STEP_X ((K0) * (V0))
1763#define LHS_STEP_LOOP (1)
1764#else // defined(INTERLEAVE)
1765#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1766#define LHS_STEP_X (K0)
1767#define LHS_STEP_LOOP (V0)
1768#endif // defined(INTERLEAVE)
1769
1770 // Block size
1771#define RHS_BLOCK_SIZE ((K0) * (N0))
1772
1773 // RHS offset and step X
1774#if defined(RHS_INTERLEAVE)
1775#define RHS_OFFSET_X (K0)
1776#define RHS_STEP_X ((K0) * (H0))
1777#define RHS_STEP_LOOP (1)
1778#else // defined(RHS_INTERLEAVE)
1779#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1780#define RHS_STEP_X (K0)
1781#define RHS_STEP_LOOP (H0)
1782#endif // defined(RHS_INTERLEAVE)
1783
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001784#if defined(DUMMY_WORK_ITEMS)
1785 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1786 {
1787 return;
1788 }
1789#endif // defined(DUMMY_WORK_ITEMS)
1790
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001791 // Compute LHS matrix address
1792 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1793 (get_global_id(2) * lhs_stride_z);
1794
1795 // Compute RHS matrix address
1796 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1797
1798#if defined(MATRIX_B_DEPTH)
1799 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1800 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1801#else // defined(MATRIX_B_DEPTH)
1802 rhs_addr += get_global_id(2) * rhs_stride_z;
1803#endif // defined(MATRIX_B_DEPTH)
1804
1805 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001806 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001807
Usama Arif0681e3b2019-04-25 14:28:07 +01001808 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1809 REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
1810
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001811 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001812 {
1813 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001814 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1815 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1816 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1817 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1818 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1819 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1820 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1821 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001822 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001823 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001824
1825 // Load values from RHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001826 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001827
1828 // Accumulate
1829 ARM_DOT_K0XN0(a0, b, c0);
1830#if M0 > 1
1831 ARM_DOT_K0XN0(a1, b, c1);
1832#endif // M0 > 1
1833#if M0 > 2
1834 ARM_DOT_K0XN0(a2, b, c2);
1835#endif // M0 > 2
1836#if M0 > 3
1837 ARM_DOT_K0XN0(a3, b, c3);
1838#endif // M0 > 3
1839#if M0 > 4
1840 ARM_DOT_K0XN0(a4, b, c4);
1841#endif // M0 > 4
1842#if M0 > 5
1843 ARM_DOT_K0XN0(a5, b, c5);
1844#endif // M0 > 5
1845#if M0 > 6
1846 ARM_DOT_K0XN0(a6, b, c6);
1847#endif // M0 > 6
1848#if M0 > 7
1849 ARM_DOT_K0XN0(a7, b, c7);
1850#endif // M0 > 7
1851
1852 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1853 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1854 }
1855
1856 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1857
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001858 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001859
1860#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001861
1862 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001863 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001864 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1865 // multiply dst_stride_z by DEPTH_GEMM3D
1866 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1867
1868#else // defined(REINTERPRET_OUTPUT_AS_3D)
1869
1870 // Add offset for batched GEMM
1871 dst_addr += get_global_id(2) * dst_stride_z;
1872
1873#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1874
1875 // Multiply by the weight of matrix-matrix product and store the result
1876#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001877 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001878#endif // defined(ALPHA)
1879
1880 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001881 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001882#undef LHS_BLOCK_SIZE
1883#undef LHS_OFFSET_X
1884#undef LHS_STEP_X
1885#undef RHS_BLOCK_SIZE
1886#undef RHS_OFFSET_X
1887#undef RHS_STEP_X
1888}
giuros01b3204e72019-04-01 13:50:22 +01001889
#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
1891
giuros01b3204e72019-04-01 13:50:22 +01001892#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
1893
1894#define VFMA(a, b, c) \
1895 ({ \
1896 c = fma(a, b, c); \
1897 })
1898
1899#if M0 == 1
1900#define RHS_VFMA_M0xN0(i, a, b, c) \
1901 ({ \
1902 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1903 })
1904#elif M0 == 2 // M0 == 2
1905#define RHS_VFMA_M0xN0(i, a, b, c) \
1906 ({ \
1907 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1908 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1909 })
1910#elif M0 == 3 // M0 == 3
1911#define RHS_VFMA_M0xN0(i, a, b, c) \
1912 ({ \
1913 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1914 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1915 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1916 })
1917#elif M0 == 4 // M0 == 4
1918#define RHS_VFMA_M0xN0(i, a, b, c) \
1919 ({ \
1920 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1921 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1922 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1923 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1924 })
1925#elif M0 == 5 // M0 == 5
1926#define RHS_VFMA_M0xN0(i, a, b, c) \
1927 ({ \
1928 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1929 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1930 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1931 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1932 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1933 })
1934#elif M0 == 6 // M0 == 6
1935#define RHS_VFMA_M0xN0(i, a, b, c) \
1936 ({ \
1937 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1938 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1939 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1940 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1941 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1942 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1943 })
1944#elif M0 == 7 // M0 == 7
1945#define RHS_VFMA_M0xN0(i, a, b, c) \
1946 ({ \
1947 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1948 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1949 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1950 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1951 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1952 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1953 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1954 })
1955#elif M0 == 8 // M0 == 8
1956#define RHS_VFMA_M0xN0(i, a, b, c) \
1957 ({ \
1958 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1959 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1960 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1961 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1962 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1963 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1964 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1965 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1966 })
1967#else // M0 not supported
1968#error "M0 not supported"
1969#endif // M0 not supported
1970
1971/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1972 * The LHS matrix is NOT reshaped
1973 * The RHS matrix is NOT reshaped
1974 *
1975 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1976 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
1977 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
1978 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1979 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2)
1980 * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
1981 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1982 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1983 * - N0 = 2, 3, 4, 8, 16
1984 * - K0 = 2, 3, 4, 8, 16
1985 *
1986 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1987 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1988 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1989 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1990 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1991 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1992 *
1993 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1994 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1995 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1996 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1997 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1998 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1999 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2000 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2001 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2002 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2003 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2004 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2005 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2006 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2007 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2008 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2009 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2010 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2011 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2012 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2013 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2014 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2015 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2016 */
2017__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2018 IMAGE_DECLARATION(rhs),
2019 IMAGE_DECLARATION(dst),
2020 uint lhs_stride_z,
2021 uint rhs_stride_z,
2022 uint dst_stride_z
2023#if defined(REINTERPRET_INPUT_AS_3D)
2024 ,
2025 uint lhs_cross_plane_pad
2026#endif // REINTERPRET_INPUT_AS_3D
2027#if defined(REINTERPRET_OUTPUT_AS_3D)
2028 ,
2029 uint dst_cross_plane_pad
2030#endif // REINTERPRET_OUTPUT_AS_3D
2031 )
2032{
2033 // Block size
2034#define RHS_BLOCK_SIZE ((K0) * (N0))
2035
2036 // RHS offset and step X
2037#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2038
2039 uint x = get_global_id(0);
2040 uint y = get_global_id(1);
2041 uint z = get_global_id(2);
2042
2043#if defined(DUMMY_WORK_ITEMS)
2044 if((x * N0 >= N) || (y * M0 >= M))
2045 {
2046 return;
2047 }
2048#endif // defined(DUMMY_WORK_ITEMS)
2049
2050 // Compute LHS matrix address
2051 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2052
2053 // Compute RHS matrix address
2054 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2055
2056#if defined(MATRIX_B_DEPTH)
2057 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2058 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2059#else // defined(MATRIX_B_DEPTH)
2060 rhs_offset += z * rhs_stride_z;
2061#endif // defined(MATRIX_B_DEPTH)
2062
2063 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
2064 REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
2065
2066#if defined(REINTERPRET_INPUT_AS_3D)
2067 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2068 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2069
2070 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2071 // multiply lhs_stride_z by DEPTH_GEMM3D
2072 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2073
2074#else // defined(REINTERPRET_INPUT_AS_3D)
2075
2076 // Add offset for batched GEMM
2077 lhs_offset += z * lhs_stride_z;
2078
2079#endif // defined(REINTERPRET_INPUT_AS_3D)
2080
2081 // Initialize the accumulators
2082 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
2083
2084 int i = 0;
2085 for(; i <= (K - K0); i += K0)
2086 {
2087 // Supported cases (M0, K0):
2088 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2089 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2090 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2091 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2092 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2093 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2094 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2095 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2096 // Load values from LHS matrix
2097 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2098
2099 // Load values from RHS matrix
2100 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
2101
2102 RHS_VFMA_M0xN0(0, a, b0, c);
2103 RHS_VFMA_M0xN0(1, a, b1, c);
2104#if K0 > 2
2105 RHS_VFMA_M0xN0(2, a, b2, c);
2106#endif // K0 > 2
2107#if K0 > 3
2108 RHS_VFMA_M0xN0(3, a, b3, c);
2109#endif // K0 > 3
2110#if K0 > 4
2111 RHS_VFMA_M0xN0(4, a, b4, c);
2112 RHS_VFMA_M0xN0(5, a, b5, c);
2113 RHS_VFMA_M0xN0(6, a, b6, c);
2114 RHS_VFMA_M0xN0(7, a, b7, c);
2115#endif // K0 > 4
2116#if K0 > 8
2117 RHS_VFMA_M0xN0(8, a, b8, c);
2118 RHS_VFMA_M0xN0(9, a, b9, c);
2119 RHS_VFMA_M0xN0(A, a, b10, c);
2120 RHS_VFMA_M0xN0(B, a, b11, c);
2121 RHS_VFMA_M0xN0(C, a, b12, c);
2122 RHS_VFMA_M0xN0(D, a, b13, c);
2123 RHS_VFMA_M0xN0(E, a, b14, c);
2124 RHS_VFMA_M0xN0(F, a, b15, c);
2125#endif // K0 > 8
2126
2127 lhs_offset += K0 * sizeof(DATA_TYPE);
2128 rhs_offset += K0 * rhs_stride_y;
2129 }
2130
2131 // Left-over accumulations
2132 for(; i < K; ++i)
2133 {
2134 // Load values from LHS matrix
2135 VEC_DATA_TYPE(DATA_TYPE, 2)
2136 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2137#if M0 > 1
2138 VEC_DATA_TYPE(DATA_TYPE, 2)
2139 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2140#endif // M0 > 1
2141#if M0 > 2
2142 VEC_DATA_TYPE(DATA_TYPE, 2)
2143 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2144#endif // M0 > 2
2145#if M0 > 3
2146 VEC_DATA_TYPE(DATA_TYPE, 2)
2147 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2148#endif // M0 > 3
2149#if M0 > 4
2150 VEC_DATA_TYPE(DATA_TYPE, 2)
2151 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2152#endif // M0 > 4
2153#if M0 > 5
2154 VEC_DATA_TYPE(DATA_TYPE, 2)
2155 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2156#endif // M0 > 5
2157#if M0 > 6
2158 VEC_DATA_TYPE(DATA_TYPE, 2)
2159 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2160#endif // M0 > 6
2161#if M0 > 7
2162 VEC_DATA_TYPE(DATA_TYPE, 2)
2163 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2164#endif // M0 > 7
2165
2166 VEC_DATA_TYPE(DATA_TYPE, N0)
2167 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2168 RHS_VFMA_M0xN0(0, a, b, c);
2169
2170 lhs_offset += sizeof(DATA_TYPE);
2171 rhs_offset += rhs_stride_y;
2172 }
2173
2174 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2175
2176 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2177
2178#if defined(REINTERPRET_OUTPUT_AS_3D)
2179 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2180 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2181
2182 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2183 // multiply dst_stride_z by DEPTH_GEMM3D
2184 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2185
2186#else // defined(REINTERPRET_OUTPUT_AS_3D)
2187
2188 // Add offset for batched GEMM
2189 dst_addr += z * dst_stride_z;
2190
2191#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2192
2193 // Multiply by the weight of matrix-matrix product and store the result
2194 // Multiply by the weight of matrix-matrix product and store the result
2195#if defined(ALPHA)
2196 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2197#endif // defined(ALPHA)
2198
2199 // Store output block
2200 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2201
2202#undef RHS_BLOCK_SIZE
2203#undef RHS_OFFSET_X
2204#undef RHS_STEP_X
2205}
2206#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2207
Gian Marco36a0a462018-01-12 10:21:40 +00002208#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002209/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002210 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002211 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002212 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2213 *
Gian Marco19835e52018-01-30 13:35:54 +00002214 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2215 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2216 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002217 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2218 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002219 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002220 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2221 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2222 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2223 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2224 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2225 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002226 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2227 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002228 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2229 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2230 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2231 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2232 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2233 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002234 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002235 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2236 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2237 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2238 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2239 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002240 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2241 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2242 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2243 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002244 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002245 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002246 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002247 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002248 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002249 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002250 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2251 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2252 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002253 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002254 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002255__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2256 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002257#if defined(ADD_VEC_C)
2258 VECTOR_DECLARATION(src2),
2259#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002260 IMAGE_DECLARATION(dst),
2261 uint src0_stride_z,
2262 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002263 uint dst_stride_z
2264#if defined(REINTERPRET_OUTPUT_AS_3D)
2265 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002266 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002267#endif // REINTERPRET_OUTPUT_AS_3D
2268 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002269{
Gian Marco36a0a462018-01-12 10:21:40 +00002270 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2271 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002272 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002273
Gian Marco36a0a462018-01-12 10:21:40 +00002274 // Offset
2275 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2276 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002277
Gian Marco36a0a462018-01-12 10:21:40 +00002278 // src_addr_a = address of matrix A
2279 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002280 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2281 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2282
2283#if defined(MATRIX_B_DEPTH)
2284 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2285 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2286#else // defined(MATRIX_B_DEPTH)
2287 src1_addr_in_bytes += z * src1_stride_z;
2288#endif // defined(MATRIX_B_DEPTH)
2289
2290 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2291 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002292
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002293 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002294 __global float *src_end_addr_b = src_addr_b + COLS_B;
2295
2296 src_addr_a += offset_row_a;
2297 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002298
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002299 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002300 float4 c00 = 0.0f;
2301 float4 c10 = 0.0f;
2302 float4 c20 = 0.0f;
2303 float4 c30 = 0.0f;
2304
Gian Marco36a0a462018-01-12 10:21:40 +00002305 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002306 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002307 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002308 float4 a0 = vload4(0, src_addr_a);
2309 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002310
2311 c00 += (float4)a0.s0 * b0;
2312 c10 += (float4)a0.s1 * b0;
2313 c20 += (float4)a0.s2 * b0;
2314 c30 += (float4)a0.s3 * b0;
2315
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002316 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002317 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2318 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002319
2320 c00 += (float4)a0.s0 * b0;
2321 c10 += (float4)a0.s1 * b0;
2322 c20 += (float4)a0.s2 * b0;
2323 c30 += (float4)a0.s3 * b0;
2324 }
2325
Gian Marco36a0a462018-01-12 10:21:40 +00002326 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002327 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002328 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002329 float4 a0 = vload4(0, src_addr_a);
2330 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002331
2332 c00 += (float4)a0.s0 * b0;
2333 c10 += (float4)a0.s1 * b0;
2334 c20 += (float4)a0.s2 * b0;
2335 c30 += (float4)a0.s3 * b0;
2336 }
2337
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002338 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002339 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2340
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002341#if defined(ALPHA)
2342 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002343 c00 = c00 * (float4)ALPHA;
2344 c10 = c10 * (float4)ALPHA;
2345 c20 = c20 * (float4)ALPHA;
2346 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002347#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002348
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002349#if defined(ADD_VEC_C)
2350 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2351 float4 c0 = vload4(0, src2_addr);
2352
2353 c00 += c0;
2354 c10 += c0;
2355 c20 += c0;
2356 c30 += c0;
2357#endif /* defined(ADD_VEC_C) */
2358
Gian Marcoae2af742018-02-15 12:35:44 +00002359 // Compute dst address
2360 __global uchar *dst_addr = offset(&dst, 0, 0);
2361
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002362#if defined(REINTERPRET_OUTPUT_AS_3D)
2363 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002364 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002365 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002366 // | |
2367 // | plane0 |
2368 // | |
2369 // |__________________|
2370 // |******************|
2371 // | cross_plane_pad |
2372 // |******************|
2373 // | |
2374 // | plane1 |
2375 // | |
2376 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002377
2378 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2379 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2380 zout = min(DEPTH_GEMM3D - 1, zout);
2381
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002382 // Add offset due to the cross plane paddings
2383 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002384
2385 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2386 // multiply dst_stride_z by DEPTH_GEMM3D
2387 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2388
2389 // Store 4x4 block
2390 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2391 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2392 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2393 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2394
2395#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002396 // Add offset for batched GEMM
2397 dst_addr += z * dst_stride_z;
2398
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002399 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002400 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2401 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2402 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2403 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002404#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002405}
2406
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002407/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002408 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
2409 *
2410 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002411 *
Gian Marco19835e52018-01-30 13:35:54 +00002412 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2413 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2414 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002415 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2416 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2417 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002418 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002419 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2420 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2421 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2422 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2423 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2424 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002425 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2426 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002427 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2428 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2429 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2430 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2431 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2432 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002433 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002434 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2435 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2436 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2437 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2438 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002439 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2440 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2441 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2442 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002443 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002444 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002445 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002446 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002447 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002448 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002449 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2450 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2451 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002452 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002453 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002454__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
2455 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002456#if defined(ADD_VEC_C)
2457 VECTOR_DECLARATION(src2),
2458#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002459 IMAGE_DECLARATION(dst),
2460 uint src0_stride_z,
2461 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002462 uint dst_stride_z
2463#if defined(REINTERPRET_OUTPUT_AS_3D)
2464 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002465 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002466#endif // REINTERPRET_OUTPUT_AS_3D
2467 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002468{
Gian Marco36a0a462018-01-12 10:21:40 +00002469 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2470 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002471 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00002472
2473 // Offset
2474 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2475 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
2476
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002477 // src_addr_a = address of matrix A
2478 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002479 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2480 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2481
2482#if defined(MATRIX_B_DEPTH)
2483 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2484 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2485#else // defined(MATRIX_B_DEPTH)
2486 src1_addr_in_bytes += z * src1_stride_z;
2487#endif // defined(MATRIX_B_DEPTH)
2488
2489 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2490 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002491
Gian Marco36a0a462018-01-12 10:21:40 +00002492 src_addr_a += offset_row_a;
2493 src_addr_b += offset_row_b;
2494
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002495 // Reset accumulators
2496 float c00 = 0.0f;
2497 float c01 = 0.0f;
2498 float c02 = 0.0f;
2499 float c03 = 0.0f;
2500 float c10 = 0.0f;
2501 float c11 = 0.0f;
2502 float c12 = 0.0f;
2503 float c13 = 0.0f;
2504 float c20 = 0.0f;
2505 float c21 = 0.0f;
2506 float c22 = 0.0f;
2507 float c23 = 0.0f;
2508 float c30 = 0.0f;
2509 float c31 = 0.0f;
2510 float c32 = 0.0f;
2511 float c33 = 0.0f;
2512
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002513#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
2514
2515 int i = 0;
2516 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002517 {
2518 // Load values from matrix A (interleaved) and matrix B (transposed)
2519 float4 a0 = vload4(0, src_addr_a);
2520 float4 b0 = vload4(0, src_addr_b);
2521
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002522 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2523 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002524
2525 c00 = fma(a0.s0, b0.s0, c00);
2526 c01 = fma(a0.s0, b0.s1, c01);
2527 c02 = fma(a0.s0, b0.s2, c02);
2528 c03 = fma(a0.s0, b0.s3, c03);
2529
2530 c10 = fma(a0.s1, b0.s0, c10);
2531 c11 = fma(a0.s1, b0.s1, c11);
2532 c12 = fma(a0.s1, b0.s2, c12);
2533 c13 = fma(a0.s1, b0.s3, c13);
2534
2535 c20 = fma(a0.s2, b0.s0, c20);
2536 c21 = fma(a0.s2, b0.s1, c21);
2537 c22 = fma(a0.s2, b0.s2, c22);
2538 c23 = fma(a0.s2, b0.s3, c23);
2539
2540 c30 = fma(a0.s3, b0.s0, c30);
2541 c31 = fma(a0.s3, b0.s1, c31);
2542 c32 = fma(a0.s3, b0.s2, c32);
2543 c33 = fma(a0.s3, b0.s3, c33);
2544
2545 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002546 a0 = vload4(0, src_addr_a);
2547 b0 = vload4(0, src_addr_b);
2548
2549 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2550 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002551
2552 c00 = fma(a0.s0, b0.s0, c00);
2553 c01 = fma(a0.s0, b0.s1, c01);
2554 c02 = fma(a0.s0, b0.s2, c02);
2555 c03 = fma(a0.s0, b0.s3, c03);
2556
2557 c10 = fma(a0.s1, b0.s0, c10);
2558 c11 = fma(a0.s1, b0.s1, c11);
2559 c12 = fma(a0.s1, b0.s2, c12);
2560 c13 = fma(a0.s1, b0.s3, c13);
2561
2562 c20 = fma(a0.s2, b0.s0, c20);
2563 c21 = fma(a0.s2, b0.s1, c21);
2564 c22 = fma(a0.s2, b0.s2, c22);
2565 c23 = fma(a0.s2, b0.s3, c23);
2566
2567 c30 = fma(a0.s3, b0.s0, c30);
2568 c31 = fma(a0.s3, b0.s1, c31);
2569 c32 = fma(a0.s3, b0.s2, c32);
2570 c33 = fma(a0.s3, b0.s3, c33);
2571
2572 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002573 a0 = vload4(0, src_addr_a);
2574 b0 = vload4(0, src_addr_b);
2575
2576 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2577 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2578
2579 c00 = fma(a0.s0, b0.s0, c00);
2580 c01 = fma(a0.s0, b0.s1, c01);
2581 c02 = fma(a0.s0, b0.s2, c02);
2582 c03 = fma(a0.s0, b0.s3, c03);
2583
2584 c10 = fma(a0.s1, b0.s0, c10);
2585 c11 = fma(a0.s1, b0.s1, c11);
2586 c12 = fma(a0.s1, b0.s2, c12);
2587 c13 = fma(a0.s1, b0.s3, c13);
2588
2589 c20 = fma(a0.s2, b0.s0, c20);
2590 c21 = fma(a0.s2, b0.s1, c21);
2591 c22 = fma(a0.s2, b0.s2, c22);
2592 c23 = fma(a0.s2, b0.s3, c23);
2593
2594 c30 = fma(a0.s3, b0.s0, c30);
2595 c31 = fma(a0.s3, b0.s1, c31);
2596 c32 = fma(a0.s3, b0.s2, c32);
2597 c33 = fma(a0.s3, b0.s3, c33);
2598
2599 // Load values from matrix A (interleaved) and matrix B (transposed)
2600 a0 = vload4(0, src_addr_a);
2601 b0 = vload4(0, src_addr_b);
2602
2603 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2604 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002605
2606 c00 = fma(a0.s0, b0.s0, c00);
2607 c01 = fma(a0.s0, b0.s1, c01);
2608 c02 = fma(a0.s0, b0.s2, c02);
2609 c03 = fma(a0.s0, b0.s3, c03);
2610
2611 c10 = fma(a0.s1, b0.s0, c10);
2612 c11 = fma(a0.s1, b0.s1, c11);
2613 c12 = fma(a0.s1, b0.s2, c12);
2614 c13 = fma(a0.s1, b0.s3, c13);
2615
2616 c20 = fma(a0.s2, b0.s0, c20);
2617 c21 = fma(a0.s2, b0.s1, c21);
2618 c22 = fma(a0.s2, b0.s2, c22);
2619 c23 = fma(a0.s2, b0.s3, c23);
2620
2621 c30 = fma(a0.s3, b0.s0, c30);
2622 c31 = fma(a0.s3, b0.s1, c31);
2623 c32 = fma(a0.s3, b0.s2, c32);
2624 c33 = fma(a0.s3, b0.s3, c33);
2625 }
2626
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002627 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002628 {
2629 // Load values from matrix A (interleaved) and matrix B (transposed)
2630 float4 a0 = vload4(0, src_addr_a);
2631 float4 b0 = vload4(0, src_addr_b);
2632
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002633 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2634 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2635
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002636 c00 = fma(a0.s0, b0.s0, c00);
2637 c01 = fma(a0.s0, b0.s1, c01);
2638 c02 = fma(a0.s0, b0.s2, c02);
2639 c03 = fma(a0.s0, b0.s3, c03);
2640
2641 c10 = fma(a0.s1, b0.s0, c10);
2642 c11 = fma(a0.s1, b0.s1, c11);
2643 c12 = fma(a0.s1, b0.s2, c12);
2644 c13 = fma(a0.s1, b0.s3, c13);
2645
2646 c20 = fma(a0.s2, b0.s0, c20);
2647 c21 = fma(a0.s2, b0.s1, c21);
2648 c22 = fma(a0.s2, b0.s2, c22);
2649 c23 = fma(a0.s2, b0.s3, c23);
2650
2651 c30 = fma(a0.s3, b0.s0, c30);
2652 c31 = fma(a0.s3, b0.s1, c31);
2653 c32 = fma(a0.s3, b0.s2, c32);
2654 c33 = fma(a0.s3, b0.s3, c33);
2655 }
2656
2657 // Compute destination address
2658 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2659
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002660#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002661 // Multiply by the weight of matrix product
2662 c00 = c00 * ALPHA;
2663 c01 = c01 * ALPHA;
2664 c02 = c02 * ALPHA;
2665 c03 = c03 * ALPHA;
2666 c10 = c10 * ALPHA;
2667 c11 = c11 * ALPHA;
2668 c12 = c12 * ALPHA;
2669 c13 = c13 * ALPHA;
2670 c20 = c20 * ALPHA;
2671 c21 = c21 * ALPHA;
2672 c22 = c22 * ALPHA;
2673 c23 = c23 * ALPHA;
2674 c30 = c30 * ALPHA;
2675 c31 = c31 * ALPHA;
2676 c32 = c32 * ALPHA;
2677 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002678#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002679
Gian Marcoae2af742018-02-15 12:35:44 +00002680 // Compute dst address
2681 __global uchar *dst_addr = offset(&dst, 0, 0);
2682
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002683#if defined(ADD_VEC_C)
2684 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2685 float4 c0 = vload4(0, src2_addr);
2686
2687 c00 += c0.s0;
2688 c01 += c0.s1;
2689 c02 += c0.s2;
2690 c03 += c0.s3;
2691 c10 += c0.s0;
2692 c11 += c0.s1;
2693 c12 += c0.s2;
2694 c13 += c0.s3;
2695 c20 += c0.s0;
2696 c21 += c0.s1;
2697 c22 += c0.s2;
2698 c23 += c0.s3;
2699 c30 += c0.s0;
2700 c31 += c0.s1;
2701 c32 += c0.s2;
2702 c33 += c0.s3;
2703#endif /* defined(ADD_VEC_C) */
2704
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002705#if defined(REINTERPRET_OUTPUT_AS_3D)
2706 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002707 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002708 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002709 // | |
2710 // | plane0 |
2711 // | |
2712 // |__________________|
2713 // |******************|
2714 // | cross_plane_pad |
2715 // |******************|
2716 // | |
2717 // | plane1 |
2718 // | |
2719 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002720
2721 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2722 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2723 zout = min(DEPTH_GEMM3D - 1, zout);
2724
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002725 // Add offset due to the cross plane paddings
2726 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002727
2728 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2729 // multiply dst_stride_z by DEPTH_GEMM3D
2730 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2731
2732 // Store 4x4 block
2733 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2734 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2735 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2736 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2737
2738#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002739 // Add offset for batched GEMM
2740 dst_addr += z * dst_stride_z;
2741
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002742 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002743 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2744 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2745 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2746 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002747#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002748}
2749
Georgios Pinitas84225582018-05-14 12:00:05 +01002750// Undefine local defines
2751#undef COLS_MTX_B
2752
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002753#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002754/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002755 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002756 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002757 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2758 *
Gian Marco19835e52018-01-30 13:35:54 +00002759 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2760 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2761 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002762 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2763 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002764 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002765 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2766 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2767 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2768 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2769 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2770 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002771 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2772 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002773 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2774 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2775 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2776 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2777 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2778 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002779 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002780 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2781 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2782 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2783 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2784 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002785 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2786 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2787 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2788 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002789 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002790 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002791 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002792 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002793 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002794 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002795 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2796 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2797 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002798 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002799 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002800__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
2801 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002802#if defined(ADD_VEC_C)
2803 VECTOR_DECLARATION(src2),
2804#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002805 IMAGE_DECLARATION(dst),
2806 uint src0_stride_z,
2807 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002808 uint dst_stride_z
2809#if defined(REINTERPRET_OUTPUT_AS_3D)
2810 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002811 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002812#endif // REINTERPRET_OUTPUT_AS_3D
2813 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002814{
Gian Marco36a0a462018-01-12 10:21:40 +00002815 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2816 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002817 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002818
Gian Marco36a0a462018-01-12 10:21:40 +00002819 // Offset
2820 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2821 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002822
Gian Marco36a0a462018-01-12 10:21:40 +00002823 // src_addr_a = address of matrix A
2824 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002825 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2826 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2827
2828#if defined(MATRIX_B_DEPTH)
2829 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2830 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2831#else // defined(MATRIX_B_DEPTH)
2832 src1_addr_in_bytes += z * src1_stride_z;
2833#endif // defined(MATRIX_B_DEPTH)
2834
2835 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
2836 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002837
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002838 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002839 __global half *src_end_addr_b = src_addr_b + COLS_B;
2840
2841 src_addr_a += offset_row_a;
2842 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002843
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002844 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002845 half8 c00 = 0.0f;
2846 half8 c10 = 0.0f;
2847 half8 c20 = 0.0f;
2848 half8 c30 = 0.0f;
2849
Gian Marco36a0a462018-01-12 10:21:40 +00002850 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002851 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002852 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002853 half4 a0 = vload4(0, src_addr_a);
2854 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002855
2856 c00 += (half8)a0.s0 * b0;
2857 c10 += (half8)a0.s1 * b0;
2858 c20 += (half8)a0.s2 * b0;
2859 c30 += (half8)a0.s3 * b0;
2860
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002861 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002862 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2863 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002864
2865 c00 += (half8)a0.s0 * b0;
2866 c10 += (half8)a0.s1 * b0;
2867 c20 += (half8)a0.s2 * b0;
2868 c30 += (half8)a0.s3 * b0;
2869 }
2870
Gian Marco36a0a462018-01-12 10:21:40 +00002871 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002872 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002873 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002874 half4 a0 = vload4(0, src_addr_a);
2875 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002876
2877 c00 += (half8)a0.s0 * b0;
2878 c10 += (half8)a0.s1 * b0;
2879 c20 += (half8)a0.s2 * b0;
2880 c30 += (half8)a0.s3 * b0;
2881 }
2882
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002883 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002884 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2885
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002886#if defined(ALPHA)
2887 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002888 c00 = c00 * (half8)ALPHA;
2889 c10 = c10 * (half8)ALPHA;
2890 c20 = c20 * (half8)ALPHA;
2891 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002892#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002893
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002894#if defined(ADD_VEC_C)
2895 // *INDENT-OFF*
2896 // clang-format off
2897 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2898 half8 c0 = vload8(0, src2_addr);
2899 // clang-format on
2900 // *INDENT-ON*
2901
2902 c00 += c0;
2903 c10 += c0;
2904 c20 += c0;
2905 c30 += c0;
2906#endif /* defined(ADD_VEC_C) */
2907
Gian Marcoae2af742018-02-15 12:35:44 +00002908 // Compute dst address
2909 __global uchar *dst_addr = offset(&dst, 0, 0);
2910
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002911#if defined(REINTERPRET_OUTPUT_AS_3D)
2912 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002913 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002914 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002915 // | |
2916 // | plane0 |
2917 // | |
2918 // |__________________|
2919 // |******************|
2920 // | cross_plane_pad |
2921 // |******************|
2922 // | |
2923 // | plane1 |
2924 // | |
2925 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002926
2927 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2928 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2929 zout = min(DEPTH_GEMM3D - 1, zout);
2930
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002931 // Add offset due to the cross plane paddings
2932 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002933
2934 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2935 // multiply dst_stride_z by DEPTH_GEMM3D
2936 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2937
2938 // Store 4x8 block
2939 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
2940 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
2941 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
2942 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
2943
2944#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002945 // Add offset for batched GEMM
2946 dst_addr += z * dst_stride_z;
2947
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002948 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +00002949 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
2950 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
2951 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
2952 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002953#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002954}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002955
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00002956/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
2957 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
2958 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002959 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2960 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00002961 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2962 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2963 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2964 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2965 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2966 *
2967 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2968 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2969 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2970 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2971 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2972 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002973 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2974 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00002975 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2976 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2977 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2978 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2979 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2980 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2981 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2982 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2983 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2984 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2985 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2986 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002987 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2988 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2989 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2990 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00002991 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2992 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2993 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2994 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2995 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2996 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2997 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2998 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2999 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3000 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3001 */
3002__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
3003 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003004#if defined(ADD_VEC_C)
3005 VECTOR_DECLARATION(src2),
3006#endif /* defined(ADD_VEC_C) */
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003007 IMAGE_DECLARATION(dst),
3008 uint src0_stride_z,
3009 uint src1_stride_z,
3010 uint dst_stride_z
3011#if defined(REINTERPRET_OUTPUT_AS_3D)
3012 ,
3013 uint cross_plane_pad
3014#endif // REINTERPRET_OUTPUT_AS_3D
3015 )
3016{
3017 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3018 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3019 int z = get_global_id(2);
3020
3021 // Offset
3022 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3023 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3024
3025 // src_addr_a = address of matrix A
3026 // src_addr_b = address of matrix B
3027 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3028 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3029
3030#if defined(MATRIX_B_DEPTH)
3031 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3032 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3033#else // defined(MATRIX_B_DEPTH)
3034 src1_addr_in_bytes += z * src1_stride_z;
3035#endif // defined(MATRIX_B_DEPTH)
3036
3037 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3038 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3039
3040 // Compute end row address for matrix B
3041 __global half *src_end_addr_b = src_addr_b + COLS_B;
3042
3043 src_addr_a += offset_row_a;
3044 src_addr_b += offset_row_b;
3045
3046 // Reset accumulators
3047 float8 c00 = 0.0f;
3048 float8 c10 = 0.0f;
3049 float8 c20 = 0.0f;
3050 float8 c30 = 0.0f;
3051
3052 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
3053 {
3054 // Load values from matrix A (interleaved) and matrix B (transposed)
3055 float4 a0 = convert_float4(vload4(0, src_addr_a));
3056 float8 b0 = convert_float8(vload8(0, src_addr_b));
3057
3058 c00 += (float8)a0.s0 * b0;
3059 c10 += (float8)a0.s1 * b0;
3060 c20 += (float8)a0.s2 * b0;
3061 c30 += (float8)a0.s3 * b0;
3062
3063 // Load values from matrix A (interleaved) and matrix B (transposed)
3064 a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
3065 b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
3066
3067 c00 += (float8)a0.s0 * b0;
3068 c10 += (float8)a0.s1 * b0;
3069 c20 += (float8)a0.s2 * b0;
3070 c30 += (float8)a0.s3 * b0;
3071 }
3072
3073 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
3074 {
3075 // Load values from matrix A (interleaved) and matrix B (transposed)
3076 float4 a0 = convert_float4(vload4(0, src_addr_a));
3077 float8 b0 = convert_float8(vload8(0, src_addr_b));
3078
3079 c00 += (float8)a0.s0 * b0;
3080 c10 += (float8)a0.s1 * b0;
3081 c20 += (float8)a0.s2 * b0;
3082 c30 += (float8)a0.s3 * b0;
3083 }
3084
3085 // Compute destination address
3086 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3087
3088#if defined(ALPHA)
3089 // Multiply by the weight of matrix product
3090 c00 = c00 * (float8)ALPHA;
3091 c10 = c10 * (float8)ALPHA;
3092 c20 = c20 * (float8)ALPHA;
3093 c30 = c30 * (float8)ALPHA;
3094#endif // defined(ALPHA)
3095
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003096#if defined(ADD_VEC_C)
3097 // *INDENT-OFF*
3098 // clang-format off
3099 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3100 float8 c0 = convert_float8(vload8(0, src2_addr));
3101 // clang-format on
3102 // *INDENT-ON*
3103
3104 c00 += c0;
3105 c10 += c0;
3106 c20 += c0;
3107 c30 += c0;
3108#endif /* defined(ADD_VEC_C) */
3109
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003110 // Compute dst address
3111 __global uchar *dst_addr = offset(&dst, 0, 0);
3112
3113#if defined(REINTERPRET_OUTPUT_AS_3D)
3114 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
3115 // in order to take into account the presence of possible cross plane paddings
3116 //
3117 // | |
3118 // | plane0 |
3119 // | |
3120 // |__________________|
3121 // |******************|
3122 // | cross_plane_pad |
3123 // |******************|
3124 // | |
3125 // | plane1 |
3126 // | |
3127 // |__________________|
3128
3129 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
3130 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3131 zout = min(DEPTH_GEMM3D - 1, zout);
3132
3133 // Add offset due to the cross plane paddings
3134 zout *= (cross_plane_pad * dst_stride_y);
3135
3136 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3137 // multiply dst_stride_z by DEPTH_GEMM3D
3138 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3139
3140 // Store 4x8 block
3141 vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3142 vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3143 vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3144 vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
3145
3146#else // defined(REINTERPRET_OUTPUT_AS_3D)
3147 // Add offset for batched GEMM
3148 dst_addr += z * dst_stride_z;
3149
3150 // Store 4x8 block
3151 vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
3152 vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
3153 vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
3154 vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
3155#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3156}
3157
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003158/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
3159 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
3160 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003161 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3162 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003163 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
3164 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
3165 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3166 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3167 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
3168 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003169 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
3170 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3171 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3172 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3173 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3174 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003175 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3176 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003177 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3178 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3179 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3180 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3181 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3182 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3183 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3184 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3185 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3186 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3187 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3188 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003189 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
3190 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3191 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3192 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003193 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3194 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3195 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3196 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3197 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3198 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003199 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003200 */
3201__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
3202 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003203#if defined(ADD_VEC_C)
3204 VECTOR_DECLARATION(src2),
3205#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003206 IMAGE_DECLARATION(dst),
3207 uint src0_stride_z,
3208 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003209 uint dst_stride_z
3210#if defined(REINTERPRET_OUTPUT_AS_3D)
3211 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003212 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003213#endif // REINTERPRET_OUTPUT_AS_3D
3214 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003215{
3216 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3217 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
3218 int z = get_global_id(2);
3219
3220 // Offset
3221 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3222 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
3223
3224 // src_addr_a = address of matrix A
3225 // src_addr_b = address of matrix B
3226 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3227 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3228
3229#if defined(MATRIX_B_DEPTH)
3230 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3231 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3232#else // defined(MATRIX_B_DEPTH)
3233 src1_addr_in_bytes += z * src1_stride_z;
3234#endif // defined(MATRIX_B_DEPTH)
3235
3236 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3237 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
3238
3239 // Compute end row address for matrix B
3240 __global half *src_end_addr_b = src_addr_b + COLS_B;
3241
3242 src_addr_a += offset_row_a;
3243 src_addr_b += offset_row_b;
3244
3245 // Reset accumulators
3246 half8 c00 = 0.0f;
3247 half8 c10 = 0.0f;
3248 half8 c20 = 0.0f;
3249 half8 c30 = 0.0f;
3250
3251#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
3252
3253 int i = 0;
3254 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
3255 {
3256#if MULT_INTERLEAVE4X4_HEIGHT == 1
3257 // Load values from matrix A (interleaved) and matrix B (transposed)
3258 half8 a0 = vload8(0, src_addr_a);
3259 half8 b0 = vload8(0, src_addr_b);
3260
3261 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3262 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3263
3264 c00 = fma((half8)a0.s0, b0, c00);
3265 c10 = fma((half8)a0.s1, b0, c10);
3266 c20 = fma((half8)a0.s2, b0, c20);
3267 c30 = fma((half8)a0.s3, b0, c30);
3268
3269 // Load values from matrix B (transposed)
3270 b0 = vload8(0, src_addr_b);
3271
3272 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3273
3274 c00 = fma((half8)a0.s4, b0, c00);
3275 c10 = fma((half8)a0.s5, b0, c10);
3276 c20 = fma((half8)a0.s6, b0, c20);
3277 c30 = fma((half8)a0.s7, b0, c30);
3278
3279 // Load values from matrix A (interleaved) and matrix B (transposed)
3280 a0 = vload8(0, src_addr_a);
3281 b0 = vload8(0, src_addr_b);
3282
3283 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
3284 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3285
3286 c00 = fma((half8)a0.s0, b0, c00);
3287 c10 = fma((half8)a0.s1, b0, c10);
3288 c20 = fma((half8)a0.s2, b0, c20);
3289 c30 = fma((half8)a0.s3, b0, c30);
3290
3291 // Load values from matrix B (transposed)
3292 b0 = vload8(0, src_addr_b);
3293
3294 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3295
3296 c00 = fma((half8)a0.s4, b0, c00);
3297 c10 = fma((half8)a0.s5, b0, c10);
3298 c20 = fma((half8)a0.s6, b0, c20);
3299 c30 = fma((half8)a0.s7, b0, c30);
3300#else // MULT_INTERLEAVE4X4_HEIGHT == 1
3301 // Load values from matrix A (interleaved) and matrix B (transposed)
3302 half4 a0 = vload4(0, src_addr_a);
3303 half8 b0 = vload8(0, src_addr_b);
3304
3305 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3306 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3307
3308 c00 = fma((half8)a0.s0, b0, c00);
3309 c10 = fma((half8)a0.s1, b0, c10);
3310 c20 = fma((half8)a0.s2, b0, c20);
3311 c30 = fma((half8)a0.s3, b0, c30);
3312
3313 // Load values from matrix A (interleaved) and matrix B (transposed)
3314 a0 = vload4(0, src_addr_a);
3315 b0 = vload8(0, src_addr_b);
3316
3317 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3318 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3319
3320 c00 = fma((half8)a0.s0, b0, c00);
3321 c10 = fma((half8)a0.s1, b0, c10);
3322 c20 = fma((half8)a0.s2, b0, c20);
3323 c30 = fma((half8)a0.s3, b0, c30);
3324
3325 // Load values from matrix A (interleaved) and matrix B (transposed)
3326 a0 = vload4(0, src_addr_a);
3327 b0 = vload8(0, src_addr_b);
3328
3329 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3330 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3331
3332 c00 = fma((half8)a0.s0, b0, c00);
3333 c10 = fma((half8)a0.s1, b0, c10);
3334 c20 = fma((half8)a0.s2, b0, c20);
3335 c30 = fma((half8)a0.s3, b0, c30);
3336
3337 // Load values from matrix A (interleaved) and matrix B (transposed)
3338 a0 = vload4(0, src_addr_a);
3339 b0 = vload8(0, src_addr_b);
3340
3341 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3342 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3343
3344 c00 = fma((half8)a0.s0, b0, c00);
3345 c10 = fma((half8)a0.s1, b0, c10);
3346 c20 = fma((half8)a0.s2, b0, c20);
3347 c30 = fma((half8)a0.s3, b0, c30);
3348#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
3349 }
3350
3351 for(; i < (int)(COLS_MTX_B); ++i)
3352 {
3353 // Load values from matrix A (interleaved) and matrix B (transposed)
3354 half4 a0 = vload4(0, src_addr_a);
3355 half8 b0 = vload8(0, src_addr_b);
3356
3357 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3358 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
3359
3360 c00 = fma((half8)a0.s0, b0, c00);
3361 c10 = fma((half8)a0.s1, b0, c10);
3362 c20 = fma((half8)a0.s2, b0, c20);
3363 c30 = fma((half8)a0.s3, b0, c30);
3364 }
3365
3366 // Compute destination address
3367 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3368
3369#if defined(ALPHA)
3370 // Multiply by the weight of matrix product
3371 c00 = c00 * (half8)ALPHA;
3372 c10 = c10 * (half8)ALPHA;
3373 c20 = c20 * (half8)ALPHA;
3374 c30 = c30 * (half8)ALPHA;
3375#endif // defined(ALPHA)
3376
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003377#if defined(ADD_VEC_C)
3378 // *INDENT-OFF*
3379 // clang-format off
3380 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3381 half8 c0 = vload8(0, src2_addr);
3382 // clang-format on
3383 // *INDENT-ON*
3384
3385 c00 += c0;
3386 c10 += c0;
3387 c20 += c0;
3388 c30 += c0;
3389#endif /* defined(ADD_VEC_C) */
3390
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003391 // Compute dst address
3392 __global uchar *dst_addr = offset(&dst, 0, 0);
3393
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003394#if defined(REINTERPRET_OUTPUT_AS_3D)
3395 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003396 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003397 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003398 // | |
3399 // | plane0 |
3400 // | |
3401 // |__________________|
3402 // |******************|
3403 // | cross_plane_pad |
3404 // |******************|
3405 // | |
3406 // | plane1 |
3407 // | |
3408 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003409
3410 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
3411 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3412 zout = min(DEPTH_GEMM3D - 1, zout);
3413
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003414 // Add offset due to the cross plane paddings
3415 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003416
3417 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3418 // multiply dst_stride_z by DEPTH_GEMM3D
3419 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3420
3421 // Store 4x8 block
3422 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3423 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3424 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3425 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
3426
3427#else // defined(REINTERPRET_OUTPUT_AS_3D)
3428 // Add offset for batched GEMM
3429 dst_addr += z * dst_stride_z;
3430
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003431 // Store 4x8 block
3432 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
3433 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
3434 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
3435 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003436#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003437}
Georgios Pinitas84225582018-05-14 12:00:05 +01003438
3439// Undefine local defines
3440#undef COLS_MTX_B
3441
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003442#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003443
Gian Marco36a0a462018-01-12 10:21:40 +00003444#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003445
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003446#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
3447#if defined(DATA_TYPE)
3448#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003449/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
3450 *
3451 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003452 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003453 * @note This OpenCL kernel works with floating point data types (F16/F32)
3454 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
3455 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003456 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003457 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3458 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003459 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003460 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3461 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003462 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3463 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3464 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3465 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3466 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003467 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3468 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003469 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003470 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3471 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3472 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3473 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3474 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003475 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003476 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3477 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3478 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3479 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3480 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003481 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
3482 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3483 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3484 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003485 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003486 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
3490 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003491 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3492 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3493 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003494 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3495 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003496 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003497__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
3498 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003499#if defined(ADD_VEC_C)
3500 VECTOR_DECLARATION(src2),
3501#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003502 IMAGE_DECLARATION(dst),
3503 uint src0_stride_z,
3504 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003505 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003506#if defined(REINTERPRET_INPUT_AS_3D)
3507 ,
3508 uint src_cross_plane_pad
3509#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003510#if defined(REINTERPRET_OUTPUT_AS_3D)
3511 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003512 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003513#endif // REINTERPRET_OUTPUT_AS_3D
3514 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003515{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003516 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003517
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003518 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003519 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003520
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003521 // Update address for the matrix A
3522 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003523
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003524 // Update address for the matrix B
3525 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003526
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003527#if defined(REINTERPRET_INPUT_AS_3D)
3528 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3529 // in order to take into account the presence of possible cross plane paddings
3530 //
3531 // | |
3532 // | plane0 |
3533 // | |
3534 // |__________________|
3535 // |******************|
3536 // | cross_plane_pad |
3537 // |******************|
3538 // | |
3539 // | plane1 |
3540 // | |
3541 // |__________________|
3542
3543 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3544 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3545 zin = min(DEPTH_GEMM3D - 1, zin);
3546
3547 // Add offset due to the cross plane paddings
3548 zin *= (src_cross_plane_pad * src0_stride_y);
3549
3550 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3551 // multiply src0_stride_z by DEPTH_GEMM3D
3552 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3553
3554#else // defined(REINTERPRET_INPUT_AS_3D)
3555
Gian Marcoae2af742018-02-15 12:35:44 +00003556 // Add offset for batched GEMM
3557 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003558
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003559#endif // defined(REINTERPRET_INPUT_AS_3D)
3560
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003561#if defined(MATRIX_B_DEPTH)
3562 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3563 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3564#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003565 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003566#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003567
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003568 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
3569
3570 VECTOR_TYPE acc0 = 0.0f;
3571#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3572 VECTOR_TYPE acc1 = 0.0f;
3573#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3574#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3575 VECTOR_TYPE acc2 = 0.0f;
3576#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3577#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3578 VECTOR_TYPE acc3 = 0.0f;
3579#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3580
Georgios Pinitas96880cf2017-10-20 18:52:20 +01003581 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003582 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003583#if defined(REINTERPRET_INPUT_AS_3D)
3584 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01003585 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3586#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003587 // Load values from matrix A
3588 VEC_DATA_TYPE(DATA_TYPE, 2)
3589 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3590#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3591 VEC_DATA_TYPE(DATA_TYPE, 2)
3592 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3593#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3594#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3595 VEC_DATA_TYPE(DATA_TYPE, 2)
3596 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3597#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3598#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3599 VEC_DATA_TYPE(DATA_TYPE, 2)
3600 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3601#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003602#endif // defined(REINTERPRET_INPUT_AS_3D)
3603
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003604 // Load values from matrix B
3605 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
3606 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003607
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003608 // Accumulate
3609 acc0 += b0 * (VECTOR_TYPE)a0.s0;
3610 acc0 += b1 * (VECTOR_TYPE)a0.s1;
3611#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3612 acc1 += b0 * (VECTOR_TYPE)a1.s0;
3613 acc1 += b1 * (VECTOR_TYPE)a1.s1;
3614#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3615#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3616 acc2 += b0 * (VECTOR_TYPE)a2.s0;
3617 acc2 += b1 * (VECTOR_TYPE)a2.s1;
3618#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3619#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3620 acc3 += b0 * (VECTOR_TYPE)a3.s0;
3621 acc3 += b1 * (VECTOR_TYPE)a3.s1;
3622#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003623 }
3624
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003625 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003626 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003627#if defined(REINTERPRET_INPUT_AS_3D)
3628 // Load values from matrix A
3629 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3630#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3631 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3632#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3633#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3634 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3635#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3636#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3637 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3639#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003640 // Load values from matrix A
3641 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3642#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3643 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3644#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3645#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3646 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3647#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3648#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3649 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3650#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003651#endif // defined(REINTERPRET_INPUT_AS_3D)
3652
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003653 // Load values from matrix B
3654 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003655
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003656 // Accumulate
3657 acc0 += b0 * (VECTOR_TYPE)a0;
3658#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3659 acc1 += b0 * (VECTOR_TYPE)a1;
3660#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3661#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3662 acc2 += b0 * (VECTOR_TYPE)a2;
3663#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3664#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3665 acc3 += b0 * (VECTOR_TYPE)a3;
3666#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003667 }
3668
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003669 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003670 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3671
Gian Marcoae2af742018-02-15 12:35:44 +00003672 // Compute dst address
3673 __global uchar *dst_addr = offset(&dst, 0, 0);
3674
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003675 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003676#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003677 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003678#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003679#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3680 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
3681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3683 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
3684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3686 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
3687#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3688
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003689#if defined(ADD_VEC_C)
3690 // *INDENT-OFF*
3691 // clang-format off
3692 __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3693 VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
3694 // clang-format on
3695 // *INDENT-ON*
3696
3697 acc0 += c0;
3698#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3699 acc1 += c0;
3700#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3702 acc2 += c0;
3703#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3704#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3705 acc3 += c0;
3706#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3707#endif /* defined(ADD_VEC_C) */
3708
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003709 int z = get_global_id(2);
3710
3711#if defined(REINTERPRET_OUTPUT_AS_3D)
3712 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003713 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003714 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003715 // | |
3716 // | plane0 |
3717 // | |
3718 // |__________________|
3719 // |******************|
3720 // | cross_plane_pad |
3721 // |******************|
3722 // | |
3723 // | plane1 |
3724 // | |
3725 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003726
3727 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3728 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3729 zout = min(DEPTH_GEMM3D - 1, zout);
3730
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003731 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003732 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003733
3734 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3735 // multiply dst_stride_z by DEPTH_GEMM3D
3736 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3737
3738 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01003739 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003740#else // defined(REINTERPRET_OUTPUT_AS_3D)
3741 // Add offset for batched GEMM
3742 dst_addr += z * dst_stride_z;
3743
3744 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003745 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003746 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003747#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003748 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003749 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003750#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3751#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003752 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003753 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003754#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3755#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003756 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003757 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003758#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003759#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003760}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003761#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003762
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003763/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003764 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003765 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3766 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003767 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
3768 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3769 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3770 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3771 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003772 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3773 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003774 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003775 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3776 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003777 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3778 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3779 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3780 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
3781 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003782 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3783 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003784 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3785 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3786 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
3787 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3788 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
3789 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3790 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3791 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3792 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
3793 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3794 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
3795 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003796 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
3797 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3798 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
3799 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003800 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
3801 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3802 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
3803 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3804 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
3805 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003806 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3807 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3808 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003809 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3810 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003811 */
3812__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
3813 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003814#if defined(ADD_VEC_C)
3815 VECTOR_DECLARATION(src2),
3816#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003817 IMAGE_DECLARATION(dst),
3818 uint src0_stride_z,
3819 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003820 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003821#if defined(REINTERPRET_INPUT_AS_3D)
3822 ,
3823 uint src_cross_plane_pad
3824#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003825#if defined(REINTERPRET_OUTPUT_AS_3D)
3826 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003827 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003828#endif // REINTERPRET_OUTPUT_AS_3D
3829 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003830{
3831 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3832
3833 // Compute starting address for matrix A and matrix B
3834 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3835
3836 // Update address for matrix A
3837 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3838
3839 // Update address for matrix B
3840 src_addr.s1 += idx * sizeof(float);
3841
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003842#if defined(REINTERPRET_INPUT_AS_3D)
3843 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3844 // in order to take into account the presence of possible cross plane paddings
3845 //
3846 // | |
3847 // | plane0 |
3848 // | |
3849 // |__________________|
3850 // |******************|
3851 // | cross_plane_pad |
3852 // |******************|
3853 // | |
3854 // | plane1 |
3855 // | |
3856 // |__________________|
3857
3858 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3859 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3860 zin = min(DEPTH_GEMM3D - 1, zin);
3861
3862 // Add offset due to the cross plane paddings
3863 zin *= (src_cross_plane_pad * src0_stride_y);
3864
3865 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3866 // multiply src0_stride_z by DEPTH_GEMM3D
3867 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3868
3869#else // defined(REINTERPRET_INPUT_AS_3D)
3870
Gian Marcoae2af742018-02-15 12:35:44 +00003871 // Add offset for batched GEMM
3872 src_addr.s0 += get_global_id(2) * src0_stride_z;
3873
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003874#endif // defined(REINTERPRET_INPUT_AS_3D)
3875
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003876#if defined(MATRIX_B_DEPTH)
3877 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3878 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3879#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003880 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003881#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003882
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003883 // Initialize accumulators
3884 float acc00 = 0.0f;
3885 float acc01 = 0.0f;
3886 float acc02 = 0.0f;
3887 float acc03 = 0.0f;
3888
3889#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3890 float acc10 = 0.0f;
3891 float acc11 = 0.0f;
3892 float acc12 = 0.0f;
3893 float acc13 = 0.0f;
3894#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3895
3896#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3897 float acc20 = 0.0f;
3898 float acc21 = 0.0f;
3899 float acc22 = 0.0f;
3900 float acc23 = 0.0f;
3901#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3902
3903#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3904 float acc30 = 0.0f;
3905 float acc31 = 0.0f;
3906 float acc32 = 0.0f;
3907 float acc33 = 0.0f;
3908#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3909
3910 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003911 int i = 0;
3912 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003913 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003914#if defined(REINTERPRET_INPUT_AS_3D)
3915 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01003916 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3917#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003918 // Load values from matrix A and matrix B
3919 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003920#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003921 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003922#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3923#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003924 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003925#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3926#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003927 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003928#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003929#endif // defined(REINTERPRET_INPUT_AS_3D)
3930
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003931 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3932 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003933
3934 // Multiply and accumulate
3935 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003936 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003937 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003938 acc03 = fma(a0.s0, b0.s3, acc03);
3939
3940#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003941
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003942 acc10 = fma(a1.s0, b0.s0, acc10);
3943 acc11 = fma(a1.s0, b0.s1, acc11);
3944 acc12 = fma(a1.s0, b0.s2, acc12);
3945 acc13 = fma(a1.s0, b0.s3, acc13);
3946
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003947#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3948#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003949
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003950 acc20 = fma(a2.s0, b0.s0, acc20);
3951 acc21 = fma(a2.s0, b0.s1, acc21);
3952 acc22 = fma(a2.s0, b0.s2, acc22);
3953 acc23 = fma(a2.s0, b0.s3, acc23);
3954
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003955#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3956#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003957
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003958 acc30 = fma(a3.s0, b0.s0, acc30);
3959 acc31 = fma(a3.s0, b0.s1, acc31);
3960 acc32 = fma(a3.s0, b0.s2, acc32);
3961 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003962#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003963
3964 // Load values from matrix A and matrix B
3965 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3966 src_addr.s1 += src1_stride_y;
3967
3968 // Multiply and accumulate
3969 acc00 = fma(a0.s1, b0.s0, acc00);
3970 acc01 = fma(a0.s1, b0.s1, acc01);
3971 acc02 = fma(a0.s1, b0.s2, acc02);
3972 acc03 = fma(a0.s1, b0.s3, acc03);
3973
3974#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3975
3976 acc10 = fma(a1.s1, b0.s0, acc10);
3977 acc11 = fma(a1.s1, b0.s1, acc11);
3978 acc12 = fma(a1.s1, b0.s2, acc12);
3979 acc13 = fma(a1.s1, b0.s3, acc13);
3980
3981#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3982#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3983
3984 acc20 = fma(a2.s1, b0.s0, acc20);
3985 acc21 = fma(a2.s1, b0.s1, acc21);
3986 acc22 = fma(a2.s1, b0.s2, acc22);
3987 acc23 = fma(a2.s1, b0.s3, acc23);
3988
3989#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3990#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3991
3992 acc30 = fma(a3.s1, b0.s0, acc30);
3993 acc31 = fma(a3.s1, b0.s1, acc31);
3994 acc32 = fma(a3.s1, b0.s2, acc32);
3995 acc33 = fma(a3.s1, b0.s3, acc33);
3996#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3997
3998 // Load values from matrix A and matrix B
3999 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4000 src_addr.s1 += src1_stride_y;
4001
4002 // Multiply and accumulate
4003 acc00 = fma(a0.s2, b0.s0, acc00);
4004 acc01 = fma(a0.s2, b0.s1, acc01);
4005 acc02 = fma(a0.s2, b0.s2, acc02);
4006 acc03 = fma(a0.s2, b0.s3, acc03);
4007
4008#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4009
4010 acc10 = fma(a1.s2, b0.s0, acc10);
4011 acc11 = fma(a1.s2, b0.s1, acc11);
4012 acc12 = fma(a1.s2, b0.s2, acc12);
4013 acc13 = fma(a1.s2, b0.s3, acc13);
4014
4015#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4016#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4017
4018 acc20 = fma(a2.s2, b0.s0, acc20);
4019 acc21 = fma(a2.s2, b0.s1, acc21);
4020 acc22 = fma(a2.s2, b0.s2, acc22);
4021 acc23 = fma(a2.s2, b0.s3, acc23);
4022
4023#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4024#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4025
4026 acc30 = fma(a3.s2, b0.s0, acc30);
4027 acc31 = fma(a3.s2, b0.s1, acc31);
4028 acc32 = fma(a3.s2, b0.s2, acc32);
4029 acc33 = fma(a3.s2, b0.s3, acc33);
4030#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4031
4032 // Load values from matrix A and matrix B
4033 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4034 src_addr.s1 += src1_stride_y;
4035
4036 // Multiply and accumulate
4037 acc00 = fma(a0.s3, b0.s0, acc00);
4038 acc01 = fma(a0.s3, b0.s1, acc01);
4039 acc02 = fma(a0.s3, b0.s2, acc02);
4040 acc03 = fma(a0.s3, b0.s3, acc03);
4041
4042#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4043
4044 acc10 = fma(a1.s3, b0.s0, acc10);
4045 acc11 = fma(a1.s3, b0.s1, acc11);
4046 acc12 = fma(a1.s3, b0.s2, acc12);
4047 acc13 = fma(a1.s3, b0.s3, acc13);
4048
4049#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4050#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4051
4052 acc20 = fma(a2.s3, b0.s0, acc20);
4053 acc21 = fma(a2.s3, b0.s1, acc21);
4054 acc22 = fma(a2.s3, b0.s2, acc22);
4055 acc23 = fma(a2.s3, b0.s3, acc23);
4056
4057#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4058#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4059
4060 acc30 = fma(a3.s3, b0.s0, acc30);
4061 acc31 = fma(a3.s3, b0.s1, acc31);
4062 acc32 = fma(a3.s3, b0.s2, acc32);
4063 acc33 = fma(a3.s3, b0.s3, acc33);
4064#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4065
4066 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004067 }
4068
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004069 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004070 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004071#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004072 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004073 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4074#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4075 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4076#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4077#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4078 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4079#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4080#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4081 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4082#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4083#else // defined(REINTERPRET_INPUT_AS_3D)
4084 // Load values from matrix A
4085 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004086#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4087 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4088#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4089#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4090 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4091#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4092#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4093 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4094#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004095#endif // defined(REINTERPRET_INPUT_AS_3D)
4096
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004097 // Load values from matrix B
4098 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004099 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004100
4101 // Multiply and accumulate
4102 acc00 = fma(a0, b0.s0, acc00);
4103 acc01 = fma(a0, b0.s1, acc01);
4104 acc02 = fma(a0, b0.s2, acc02);
4105 acc03 = fma(a0, b0.s3, acc03);
4106#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4107 acc10 = fma(a1, b0.s0, acc10);
4108 acc11 = fma(a1, b0.s1, acc11);
4109 acc12 = fma(a1, b0.s2, acc12);
4110 acc13 = fma(a1, b0.s3, acc13);
4111#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4112#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4113 acc20 = fma(a2, b0.s0, acc20);
4114 acc21 = fma(a2, b0.s1, acc21);
4115 acc22 = fma(a2, b0.s2, acc22);
4116 acc23 = fma(a2, b0.s3, acc23);
4117#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4118#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4119 acc30 = fma(a3, b0.s0, acc30);
4120 acc31 = fma(a3, b0.s1, acc31);
4121 acc32 = fma(a3, b0.s2, acc32);
4122 acc33 = fma(a3, b0.s3, acc33);
4123#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004124
4125 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004126 }
4127
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004128 int z = get_global_id(2);
4129
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004130 // Compute destination address
4131 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4132
4133 // Multiply by the weight of matrix-matrix product and store the result
4134#if defined(ALPHA)
4135 acc00 = acc00 * ALPHA;
4136 acc01 = acc01 * ALPHA;
4137 acc02 = acc02 * ALPHA;
4138 acc03 = acc03 * ALPHA;
4139#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004140#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004141 acc10 = acc10 * ALPHA;
4142 acc11 = acc11 * ALPHA;
4143 acc12 = acc12 * ALPHA;
4144 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004145#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4146#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004147 acc20 = acc20 * ALPHA;
4148 acc21 = acc21 * ALPHA;
4149 acc22 = acc22 * ALPHA;
4150 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004151#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4152#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004153 acc30 = acc30 * ALPHA;
4154 acc31 = acc31 * ALPHA;
4155 acc32 = acc32 * ALPHA;
4156 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004157#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4158
4159 // Compute dst address
4160 __global uchar *dst_addr = offset(&dst, 0, 0);
4161
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004162#if defined(ADD_VEC_C)
4163 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4164 float4 c0 = vload4(0, src2_addr);
4165
4166 acc00 += c0.s0;
4167 acc01 += c0.s1;
4168 acc02 += c0.s2;
4169 acc03 += c0.s3;
4170#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4171 acc10 += c0.s0;
4172 acc11 += c0.s1;
4173 acc12 += c0.s2;
4174 acc13 += c0.s3;
4175#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4177 acc20 += c0.s0;
4178 acc21 += c0.s1;
4179 acc22 += c0.s2;
4180 acc23 += c0.s3;
4181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4183 acc30 += c0.s0;
4184 acc31 += c0.s1;
4185 acc32 += c0.s2;
4186 acc33 += c0.s3;
4187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4188#endif /* defined(ADD_VEC_C) */
4189
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004190#if defined(REINTERPRET_OUTPUT_AS_3D)
4191 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004192 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004193 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004194 // | |
4195 // | plane0 |
4196 // | |
4197 // |__________________|
4198 // |******************|
4199 // | cross_plane_pad |
4200 // |******************|
4201 // | |
4202 // | plane1 |
4203 // | |
4204 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004205
4206 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4207 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4208 zout = min(DEPTH_GEMM3D - 1, zout);
4209
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004210 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004211 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004212
4213 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4214 // multiply dst_stride_z by DEPTH_GEMM3D
4215 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4216
4217 // Store the output block
4218 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4219#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4220 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4221#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4222#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4223 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4225#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4226 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004227#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004228
4229#else // defined(REINTERPRET_OUTPUT_AS_3D)
4230 // Add offset for batched GEMM
4231 dst_addr += z * dst_stride_z;
4232
4233 // Store the output block
4234 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4235#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4236 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4237#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4238#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4239 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4240#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4241#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4242 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4243#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4244#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004245}
4246
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * Each work-item accumulates an output tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows (1 to 4) by 2 columns.
 * The reduction over COLS_A is unrolled by 8 and followed by a scalar left-over loop. The tile is then scaled by ALPHA (if defined),
 * offset by src2 (if ADD_VEC_C is defined) and stored to dst.
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       The kernel always loads and stores 2 columns per work-item (vload2/vstore2), so it is meant to be used with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *          NOTE(review): the kernel derives the plane from the row index (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y), so "columns" above looks like it should read "rows" - confirm against the host-side configuration.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings of the input tensor, expressed in rows: the value is multiplied by src0_stride_y (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings of the output tensor, expressed in rows: the value is multiplied by dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                      VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Per work-item tile: NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 2 columns of C.
    // Each unrolled iteration consumes 8 values of a row of A (vload8) and 8 rows of B, 2 columns each (8 x vload2).
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (byte offsets: .s0 -> A, .s1 -> B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly so that the call matches the min(gentype, gentype) built-in
    zin = min(zin, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings (zin becomes a per-row byte offset)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators (accRC: R = row of the tile, C = column of the tile)
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: 8 accumulations per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (8 consecutive rows, 2 columns each)
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        // Second row of the tile: a0 is reused for the next row of A
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        // Third row of the tile
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        // Fourth row of the tile
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance A by the 8 floats consumed in this iteration
        src_addr.s0 += sizeof(float) * 8;
    }
    // Left-over accumulations: one column of A / one row of B per iteration
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance A by the single float consumed in this iteration
        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ADD_VEC_C)
    // Add the vector src2 (after the ALPHA scaling): the same 2 values are added to every row of the tile
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float2 c0                 = vload2(0, src2_addr);

    acc00 += c0.s0;
    acc01 += c0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 += c0.s0;
    acc11 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 += c0.s0;
    acc21 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 += c0.s0;
    acc31 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly so that the call matches the min(gentype, gentype) built-in
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings (zout becomes a per-row byte offset)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
4664
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004665#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01004666/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4667 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004668 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4669 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004670 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
4671 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4672 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
4673 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4674 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
4675 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
4676 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
4677 *
4678 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4679 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
4680 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4681 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4682 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4683 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4684 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004685 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4686 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00004687 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4688 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4689 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4690 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector that is added to every output row. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Half-precision GEMM on non-reshaped A (src0) and B (src1) that accumulates in float (float8 accumulators)
    // and converts back to half only before the optional ALPHA scaling, the optional src2 addition and the store.
    // Each work-item produces up to 4 rows (NUM_ELEMS_PROCESSED_PER_THREAD_Y) of 8 half values (vload8 / vstore8).

    // First column of matrix B / of the output handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (byte offsets: .s0 -> src0, .s1 -> src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (one byte offset per processed row)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // 32-bit accumulators, one per output row
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and therefore 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (LOAD_BLOCK declares a0..a3 and applies the per-row zin offsets)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (widened to float) and move to the next row of B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate: first of the 4 columns of A
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Second column of A
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Third column of A
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Fourth column of A
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 half values just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining columns of A (COLS_A not a multiple of 4), one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        // Advance A by one half value and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Convert the 32-bit accumulators to the half output type
    half8 hacc0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 hacc1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 hacc2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 hacc3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Multiply by the weight of matrix-matrix product.
    // The scaling is applied on the half values, i.e. after the conversion from float.
#if defined(ALPHA)
    hacc0 = hacc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    hacc1 = hacc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    hacc2 = hacc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    hacc3 = hacc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 8 values of the src2 vector to every output row
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8         c0         = vload8(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    hacc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    hacc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    hacc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    hacc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block (STORE_BLOCK consumes hacc0..hacc3 and the per-row zout offsets)
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
5024
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 * @note The accumulation is performed in half precision (half8 accumulators).
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads 8 values of matrix B per step and stores 8 output values per row (vload8/vstore8), for up to 4 rows.
 *       NOTE(review): this comment previously recommended -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4, which does not match the 8-wide loads/stores; confirm the value set by the host code.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped.
 *          NOTE(review): this comment previously said "columns"; the kernel derives the plane index from the row index, confirm against the host code.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector that is added to every output row. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D). The kernel multiplies it by src0_stride_y.
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D). The kernel multiplies it by dst_stride_y.
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // First column of matrix B / of the output handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (byte offsets: .s0 -> src0, .s1 -> src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (one byte offset per processed row)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Half-precision accumulators, one per output row
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and therefore 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (LOAD_BLOCK declares a0..a3 and applies the per-row zin offsets)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B and move to the next row of B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate: first of the 4 columns of A
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Second column of A
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Third column of A
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Fourth column of A
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 half values just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining columns of A (COLS_A not a multiple of 4), one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one half value and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 8 values of the src2 vector to every output row
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8         c0         = vload8(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block (STORE_BLOCK consumes acc0..acc3 and the per-row zout offsets)
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005368#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005369
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005370#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005371
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005372#if defined(BETA)
/** In-place matrix addition for F32: dst = dst + BETA * src.
 *
 * Each work-item reads 4 consecutive floats from the destination (the A x B result) and 4 from the
 * source (matrix C), scales the source values by BETA, adds them and writes the 4 results back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note Any alpha scaling is presumably already applied to the values stored in dst; this kernel only applies BETA.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (A x B on input, final result on output). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per-work-item views of matrix C (src) and of the A x B result (dst)
    Tensor3D mtx_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mtx_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *ab_ptr = (__global float *)mtx_ab.ptr;
    __global float *c_ptr  = (__global float *)mtx_c.ptr;

    // Fetch the A x B values first, then the matching values of matrix C
    float4 ab_vals = vload4(0, ab_ptr);
    float4 c_vals  = vload4(0, c_ptr);

    // dst = (A x B) + beta * C, written back over the A x B values
    vstore4(ab_vals + (float4)BETA * c_vals, 0, ab_ptr);
}
5413
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005414#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** In-place matrix addition for F16: dst = dst + BETA * src.
 *
 * Each work-item reads 8 consecutive halfs from the destination (the A x B result) and 8 from the
 * source (matrix C), scales the source values by BETA, adds them and writes the 8 results back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note Any alpha scaling is presumably already applied to the values stored in dst; this kernel only applies BETA.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (A x B on input, final result on output). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per-work-item views of matrix C (src) and of the A x B result (dst)
    Tensor3D mtx_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mtx_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *ab_ptr = (__global half *)mtx_ab.ptr;
    __global half *c_ptr  = (__global half *)mtx_c.ptr;

    // Fetch the A x B values first, then the matching values of matrix C
    half8 ab_vals = vload8(0, ab_ptr);
    half8 c_vals  = vload8(0, c_ptr);

    // dst = (A x B) + beta * C, written back over the A x B values
    vstore8(ab_vals + (half8)BETA * c_vals, 0, ab_ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005455#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005456#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005457
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005458#if defined(WIDTH_VECTOR_A)
/** Vector by matrix multiplication between each row of A (src0) and matrix B (src1), used for the locally connected layer.
 *
 * Work-item (x, y) takes row y of src0 as the vector, selects the matrix at Z index y of src1 and
 * produces 4 consecutive F32 output values starting at column 4 * x of row y of dst.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix A
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First output column handled by this work-item (4 columns per work-item) and the row index
    int col = get_global_id(0) * 4;
    int row = get_global_id(1);

    // Byte offset of the current element of vector A (row of src0)
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;

    // Byte offset of the current 4-wide slice of matrix B (Z plane "row" of src1, starting at column "col")
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += col * sizeof(float);

    // Byte offset one past the last element of vector A
    int a_end = a_off + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume two elements of A (and two rows of B) per iteration
    while(a_off <= (a_end - 2 * (int)sizeof(float)))
    {
        float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_off));
        float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_off));
        float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        a_off += 2 * sizeof(float);
        b_off += 2 * src1_stride_y;
    }

    // Left-over loop: consume the remaining element of A (if any) one at a time
    while(a_off < a_end)
    {
        float  a_val = *((__global float *)(src0_ptr + a_off));
        float4 b_row = vload4(0, (__global float *)(src1_ptr + b_off));

        acc += b_row * (float4)a_val;

        a_off += sizeof(float);
        b_off += src1_stride_y;
    }

    // Resolve the destination address for this work-item and write the 4 accumulated values
    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);

    __global float *out_ptr = (__global float *)(offset(&out_img, 0, 0));

    vstore4(acc, 0, out_ptr);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005524#endif // defined(WIDTH_VECTOR_A)
5525
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the
 * biases vector, sums them and stores the result back into the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size (number of elements processed per work-item) must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Per-work-item views of the accumulate tensor and of the biases vector
    Image  acc_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector bias_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *acc_ptr  = (__global DATA_TYPE *)acc_img.ptr;
    __global DATA_TYPE *bias_ptr = (__global DATA_TYPE *)bias_vec.ptr;

    // Load VECTOR_SIZE elements from each input
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_in = VLOAD(VECTOR_SIZE)(0, acc_ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_in = VLOAD(VECTOR_SIZE)(0, bias_ptr);

    // Write biases + accum back over the accumulate values
    VSTORE(VECTOR_SIZE)
    (bias_in + acc_in, 0, acc_ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)