blob: 213075df07c5cca550455d590c043beea446922c [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
/* Per-lane column indices (0 .. width-1), one constant for every supported vector width. */
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)

/* Two-level expansion: INC(K0) first replaces K0 by its value, then CONCAT_INC pastes it
 * onto "INC" to select one of the constants above. The parameter is deliberately not
 * called K0 so that it cannot be confused with the compile-time constant of that name.
 */
#define CONCAT_INC(width) INC##width
#define INC(width) CONCAT_INC(width)

#if(SRC_WIDTH % K0)
/* SRC_WIDTH is not a multiple of K0: the last block along X is only partially valid.
 * BOUNDARY_CONDITION_X(x, a) overwrites in place the vector a (K0 elements loaded at block
 * index x along X) so that every lane whose column index (x * K0 + lane) is not smaller
 * than SRC_WIDTH becomes 0. Lanes inside the matrix are left untouched.
 */
#define BOUNDARY_CONDITION_X(x, a)                                                                                                                         \
    ({                                                                                                                                                     \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
/* SRC_WIDTH is a multiple of K0: every block is fully inside the matrix, nothing to mask. */
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
46/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
47 * the output matrix unrolling the values.
48 *
49 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000050 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000051 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
52 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
53 * @note Only the following values for M0, K0 and V0 are supported:
54 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000055 * K0: 2,3,4,8,16
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000056 * V0: greater than 0
57 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
58 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
59 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
60 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
61 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
63 *
64 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
65 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
66 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
67 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
68 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
69 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
70 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
71 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
72 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
73 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
74 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
75 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
76 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
77 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
78 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
79 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
80 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
81 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: offset (in elements) of the (y % V0)-th block inside its group of V0 blocks
    // OUTPUT_STEP_X  : stride (in elements) handed to STORE_BLOCK for the rows a0, a1, ... of the block
    // Both expansions are fully parenthesized so they can be used safely inside larger expressions.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: the block starts K0 elements per x step and M0 rows per y step into the source
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes
                                + x * (uint)K0 * sizeof(DATA_TYPE)
                                + y * (uint)M0 * src_stride_y;

    // Output: every x step skips V0 whole blocks, (y / V0) selects the output row and
    // (y % V0) selects the position of the block inside its group of V0 blocks
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % (uint)V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Add offset for batched GEMM on the destination
    output_ptr += z * (uint)dst_stride_z;

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM on the source
    input_ptr += z * (uint)src_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows (a0 ... a(M0-1)) of K0 elements each from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Zero the lanes that fall beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Create variables: uint zout0=0, zout1=0, ... zout15=0 (no extra Z offset on the destination)
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
/* TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
 *
 * Gathers lane i (a single hexadecimal digit 0..F) of the row vectors a0 ... a(M0-1) that are
 * in scope at the point of use, and stores those M0 values contiguously at
 *     output_ptr + i * output_step_x * sizeof(DATA_TYPE)
 * where output_ptr is a byte pointer and output_step_x is expressed in elements.
 *
 * i is only ever used through token pasting (a0.s##i and 0x##i), so it must be passed as a
 * literal digit. For M0 = 5, 6 and 7 the column is written as a 4-wide store followed by the
 * remaining 1, 2 or 3 values, as no single store of that width is used here.
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                      \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                                           \
        VSTORE(M0)                                                                                                        \
        (col, 0, col_dst);                                                                                                \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                      \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                                                  \
        VSTORE(M0)                                                                                                        \
        (col, 0, col_dst);                                                                                                \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                      \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                         \
        VSTORE(M0)                                                                                                        \
        (col, 0, col_dst);                                                                                                \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                       \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                       \
        DATA_TYPE col_hi = a4.s##i;                                                                                       \
        VSTORE(4)                                                                                                         \
        (col_lo, 0, col_dst);                                                                                             \
        *(col_dst + 4) = col_hi;                                                                                          \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                       \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                       \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                       \
        col_hi = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                                         \
        VSTORE(4)                                                                                                         \
        (col_lo, 0, col_dst);                                                                                             \
        VSTORE(2)                                                                                                         \
        (col_hi, 0, col_dst + 4);                                                                                         \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                                       \
        col_lo = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                       \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                       \
        col_hi = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                                \
        VSTORE(4)                                                                                                         \
        (col_lo, 0, col_dst);                                                                                             \
        VSTORE(3)                                                                                                         \
        (col_hi, 0, col_dst + 4);                                                                                         \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                          \
    ({                                                                                                                    \
        __global DATA_TYPE *col_dst = (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)); \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                      \
        col = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i);     \
        VSTORE(M0)                                                                                                        \
        (col, 0, col_dst);                                                                                                \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
245
246/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
247 * the output matrix unrolling the values.
248 *
249 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000250 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000251 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
252 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
253 * @note Only the following values for M0, K0 and V0 are supported:
254 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000255 * K0: 2,3,4,8,16
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000256 * V0: greater than 0
257 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
258 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
259 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
260 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
261 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
263 *
264 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
265 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
266 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
267 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
268 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
269 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
270 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
271 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
272 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
273 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
274 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
275 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
276 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
277 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
278 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
280 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
281 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: offset (in elements) of the (y % V0)-th block inside its group of V0 blocks
    // OUTPUT_STEP_X  : distance (in elements) between two consecutive transposed columns of the block
    // Both expansions are fully parenthesized so they can be used safely inside larger expressions.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: the block starts K0 elements per x step and M0 rows per y step into the source
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes
                                + x * (uint)K0 * sizeof(DATA_TYPE)
                                + y * (uint)M0 * src_stride_y;

    // Output: every x step skips V0 whole blocks, (y / V0) selects the output row and
    // (y % V0) selects the position of the block inside its group of V0 blocks
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % (uint)V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Add offset for batched GEMM on the destination
    output_ptr += z * (uint)dst_stride_z;

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM on the source
    input_ptr += z * (uint)src_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows (a0 ... a(M0-1)) of K0 elements each from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Zero the lanes that fall beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per column of the block: the last argument is the column index written as a
    // single hexadecimal digit because the macro pastes it into a vector component name.
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
402/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
403 * the output matrix unrolling the values.
404 *
405 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
406 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
407 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
408 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
410 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000411 * N0: 2,3,4,8,16
412 * K0: 1,2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000413 * H0: greater than 0
414 *
415 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
416 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
417 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
418 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
419 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
420 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
421 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
422 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
423 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
424 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
425 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
426 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
427 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
428 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
429 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
430 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
431 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: offset (in elements) of the (x % H0)-th block inside its group of H0 blocks
    // OUTPUT_STEP_X  : stride (in elements) handed to STORE_BLOCK for the rows a0, a1, ... of the block
    // Both expansions are fully parenthesized so they can be used safely inside larger expressions.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Loads row (y * K0 + row_idx) of the block into row_vec, but only when that row lies inside
    // the source matrix; otherwise row_vec keeps the zero it was initialised with.
#define LOAD_ROW_IF_VALID(row_idx, row_vec)                                                             \
    do                                                                                                  \
    {                                                                                                   \
        if(y * (uint)K0 + row_idx < SRC_HEIGHT)                                                         \
        {                                                                                               \
            row_vec = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + row_idx * src_stride_y));         \
        }                                                                                               \
    } while(0)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input: the block starts N0 elements per x step and K0 rows per y step into the source
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes
                                + x * (uint)N0 * sizeof(DATA_TYPE)
                                + y * (uint)K0 * src_stride_y
                                + z * (uint)src_stride_z;

    // Output: every y step skips H0 whole blocks, (x / H0) selects the output row and
    // (x % H0) selects the position of the block inside its group of H0 blocks
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE))
                                 + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. The first row is loaded unconditionally; every
    // further row is loaded only if it is inside the SRC_HEIGHT rows of the source.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    LOAD_ROW_IF_VALID(1, a1);
#endif // K0 > 1
#if K0 > 2
    LOAD_ROW_IF_VALID(2, a2);
#endif // K0 > 2
#if K0 > 3
    LOAD_ROW_IF_VALID(3, a3);
#endif // K0 > 3
#if K0 > 4
    LOAD_ROW_IF_VALID(4, a4);
    LOAD_ROW_IF_VALID(5, a5);
    LOAD_ROW_IF_VALID(6, a6);
    LOAD_ROW_IF_VALID(7, a7);
#endif // K0 > 4
#if K0 > 8
    LOAD_ROW_IF_VALID(8, a8);
    LOAD_ROW_IF_VALID(9, a9);
    LOAD_ROW_IF_VALID(10, aA);
    LOAD_ROW_IF_VALID(11, aB);
    LOAD_ROW_IF_VALID(12, aC);
    LOAD_ROW_IF_VALID(13, aD);
    LOAD_ROW_IF_VALID(14, aE);
    LOAD_ROW_IF_VALID(15, aF);
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    // Create variables: uint zout0=0, zout1=0, ... zout15=0 (no extra Z offset on the destination)
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef LOAD_ROW_IF_VALID
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
553
554#if defined(TRANSPOSE)
555/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
556 * the output matrix unrolling the values.
557 *
558 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
559 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
560 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
561 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
564 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000565 * N0: 2,3,4,8,16
566 * K0: 2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000567 * H0: greater than 0
568 *
569 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
570 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
571 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
573 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
575 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
576 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
577 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
578 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
579 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
580 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
581 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
582 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
583 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
584 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
585 */
586__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
587 TENSOR3D_DECLARATION(dst))
588{
589 // Block size
590#define BLOCK_SIZE ((K0) * (N0))
591
592 // Output offset X
593#if defined(INTERLEAVE)
594#define OUTPUT_OFFSET_X (K0)
595#else // defined(INTERLEAVE)
596#define OUTPUT_OFFSET_X (BLOCK_SIZE)
597#endif // defined(INTERLEAVE)
598
599 // Output step X
600#if defined(INTERLEAVE)
601#define OUTPUT_STEP_X (K0) * (H0)
602#else // Do not interleave
603#define OUTPUT_STEP_X (K0)
604#endif // defined(INTERLEAVE)
605
606 // Compute source and destination addresses
607 uint x = get_global_id(0);
608 uint y = get_global_id(1);
609 uint z = get_global_id(2);
610
611 // ------------------ Compute input/output addresses ---------------------------
612
613 // Compute the input address
614 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
615
616 // Compute the output address
617 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
618 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
619
620 // ---------------------------Load input values --------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000621 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000622
623 // Load values from the RHS matrix
624 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
625 if(y * (uint)K0 + 1 < SRC_HEIGHT)
626 {
627 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
628 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000629#if K0 > 2
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000630 if(y * (uint)K0 + 2 < SRC_HEIGHT)
631 {
632 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
633 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000634#endif // K0 > 2
635#if K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000636 if(y * (uint)K0 + 3 < SRC_HEIGHT)
637 {
638 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
639 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000640#endif // K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000641#if K0 > 4
642 if(y * (uint)K0 + 4 < SRC_HEIGHT)
643 {
644 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
645 }
646 if(y * (uint)K0 + 5 < SRC_HEIGHT)
647 {
648 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
649 }
650 if(y * (uint)K0 + 6 < SRC_HEIGHT)
651 {
652 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
653 }
654 if(y * (uint)K0 + 7 < SRC_HEIGHT)
655 {
656 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
657 }
658#endif // K0 > 4
659#if K0 > 8
Gian Marco Iodice89124342018-12-19 14:17:22 +0000660 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000661 {
662 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
663 }
664 if(y * (uint)K0 + 9 < SRC_HEIGHT)
665 {
666 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
667 }
668 if(y * (uint)K0 + 10 < SRC_HEIGHT)
669 {
670 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
671 }
672 if(y * (uint)K0 + 11 < SRC_HEIGHT)
673 {
674 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
675 }
676 if(y * (uint)K0 + 12 < SRC_HEIGHT)
677 {
678 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
679 }
680 if(y * (uint)K0 + 13 < SRC_HEIGHT)
681 {
682 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
683 }
684 if(y * (uint)K0 + 14 < SRC_HEIGHT)
685 {
686 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
687 }
688 if(y * (uint)K0 + 15 < SRC_HEIGHT)
689 {
690 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
691 }
692#endif // K0 > 8
693
694 // ---------------------------Transpose the block ------------------------------
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000695 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000696
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000697#if K0 == 2
698 // This part computes the following transpositions:
699 // 2x2 -> 2x2
700 // 2x4 -> 4x2
701 // 2x8 -> 8x2
702 // 2x16 -> 16x2
703 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
704 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
705#if N0 > 2
706 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
707#endif // N0 > 2
708#if N0 > 3
709 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
710#endif // N0 > 3
711#if N0 > 4
712 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
713 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
714 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
715 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
716#endif // N0 > 4
717#if N0 > 8
718 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
719 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
720 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
721 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
722 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
723 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
724 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
725 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
726#endif // N0 > 8
727
728#elif K0 == 3 // K0 == 2
729 // This part computes the following transpositions:
730 // 3x2 -> 2x3
731 // 3x4 -> 4x3
732 // 3x8 -> 8x3
733 // 3x16 -> 16x3
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100734 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
735 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000736#if N0 > 2
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100737 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000738#endif // N0 > 2
739#if N0 > 3
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100740 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000741#endif // N0 > 3
742#if N0 > 4
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100743 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
744 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
745 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
746 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000747#endif // N0 > 4
748#if N0 > 8
Georgios Pinitasb0f342e2019-05-21 13:32:43 +0100749 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
750 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
751 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
752 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
753 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
754 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
755 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
756 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000757#endif // N0 > 8
758
759#elif K0 == 4 // K0 == 4
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000760 // This part computes the following transpositions:
761 // 4x2 -> 2x4
762 // 4x4 -> 4x4
763 // 4x8 -> 8x4
764 // 4x16 -> 16x4
765 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
766 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
767#if N0 > 2
768 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000769#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000770#if N0 > 3
771 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
772#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000773#if N0 > 4
774 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
775 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
776 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
777 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
778#endif // N0 > 4
779#if N0 > 8
780 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
781 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
782 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
783 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
784 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
785 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
786 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
787 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
788#endif // N0 > 8
789
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000790#elif K0 == 8 // K0 == 8
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000791 // This part computes the following transpositions:
792 // 8x2 -> 2x8
793 // 8x4 -> 4x8
794 // 8x8 -> 8x8
795 // 8x16 -> 16x8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000796 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
797 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000798#if N0 > 2
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000799 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000800#endif // N0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000801#if N0 > 3
802 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
803#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000804#if N0 > 4
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000805 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
806 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
807 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
808 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000809#endif // N0 > 4
810#if N0 > 8
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000811 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
812 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
813 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
814 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
815 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
816 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
817 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
818 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000819#endif // N0 > 8
820
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000821#elif K0 == 16 // K0 == 16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000822
823 // This part computes the following transpositions:
824 // 16x2 -> 2x16
825 // 16x4 -> 4x16
826 // 16x8 -> 8x16
827 // 16x16 -> 16x16
828 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
829 a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
830 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
831 a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
832#if N0 > 2
833 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
834 a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000835#endif // N0 > 2
836#if N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000837 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
838 a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000839#endif // N0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000840#if N0 > 4
841 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
842 a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
843 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
844 a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
845 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
846 a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
847 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
848 a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
849#endif // N0 > 4
850#if N0 > 8
851 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
852 a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
853 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
854 a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
855 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
856 a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
857 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
858 a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
859 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
860 a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
861 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
862 a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
863 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
864 a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
865 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
866 a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
867#endif // N0 > 8
868
869#else // N0 == 16
870#error "Not supported N0 value"
871#endif // N0 > 2
872
873 // ---------------------------Store the output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100874 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
875 STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000876
877#undef BLOCK_SIZE
878#undef OUTPUT_OFFSET_X
879#undef OUTPUT_STEP_X
880}
881#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
/* Pastes two tokens together (used to pick ARM_DOT<k0> from a numeric k0). */
#define CONCAT(a, b) a##b

/* ARM_DOT<n>(a, b, c): c += dot(a, b) over n elements, accumulated with one fma per element,
 * in ascending component order. ARM_DOT1 takes scalars; the others take n-component vectors. */
#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)            \
    ({                               \
        ARM_DOT1((a).s0, (b).s0, c); \
        ARM_DOT1((a).s1, (b).s1, c); \
    })
#define ARM_DOT3(a, b, c)            \
    ({                               \
        ARM_DOT2(a, b, c);           \
        ARM_DOT1((a).s2, (b).s2, c); \
    })
#define ARM_DOT4(a, b, c)            \
    ({                               \
        ARM_DOT3(a, b, c);           \
        ARM_DOT1((a).s3, (b).s3, c); \
    })
/* The 8- and 16-wide variants process the low half first, then the high half. */
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a).lo, (b).lo, c); \
        ARM_DOT4((a).hi, (b).hi, c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a).lo, (b).lo, c); \
        ARM_DOT8((a).hi, (b).hi, c); \
    })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (i.e. -DM=52, -DN=30 and -DK=90)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
1026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001034 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
1039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1048 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1049 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1050 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1051 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1052 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001053 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1054 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1055 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1056 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1057 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1058 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1059 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1060 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001061 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001062 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1063 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1064 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001065 */
1066__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1067 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001068#if defined(BETA)
1069 IMAGE_DECLARATION(bias),
1070#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001071 IMAGE_DECLARATION(dst),
1072 uint lhs_stride_z,
1073 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001074#if defined(BETA)
1075 uint bias_stride_z,
1076#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001077 uint dst_stride_z
1078#if defined(REINTERPRET_INPUT_AS_3D)
1079 ,
1080 uint lhs_cross_plane_pad
1081#endif // REINTERPRET_INPUT_AS_3D
1082#if defined(REINTERPRET_OUTPUT_AS_3D)
1083 ,
1084 uint dst_cross_plane_pad
1085#endif // REINTERPRET_OUTPUT_AS_3D
1086 )
1087{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001088 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001089#define RHS_BLOCK_SIZE ((K0) * (N0))
1090
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001091 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001092#if defined(RHS_INTERLEAVE)
1093#define RHS_OFFSET_X (K0)
1094#define RHS_STEP_X ((K0) * (H0))
1095#define RHS_STEP_LOOP (1)
1096#else // defined(RHS_INTERLEAVE)
1097#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1098#define RHS_STEP_X (K0)
1099#define RHS_STEP_LOOP (H0)
1100#endif // defined(RHS_INTERLEAVE)
1101
1102 uint x = get_global_id(0);
1103 uint y = get_global_id(1);
1104 uint z = get_global_id(2);
1105
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001106#if defined(DUMMY_WORK_ITEMS)
1107 if((x * N0 >= N) || (y * M0 >= M))
1108 {
1109 return;
1110 }
1111#endif // defined(DUMMY_WORK_ITEMS)
1112
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001113 // Compute LHS matrix address
1114 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1115
1116 // Compute RHS matrix address
1117 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1118
1119#if defined(MATRIX_B_DEPTH)
1120 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1121 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1122#else // defined(MATRIX_B_DEPTH)
1123 rhs_offset += z * rhs_stride_z;
1124#endif // defined(MATRIX_B_DEPTH)
1125
Usama Arif0681e3b2019-04-25 14:28:07 +01001126 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001127 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001128
1129#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001130 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1131 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001132
1133 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1134 // multiply lhs_stride_z by DEPTH_GEMM3D
1135 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1136
1137#else // defined(REINTERPRET_INPUT_AS_3D)
1138
1139 // Add offset for batched GEMM
1140 lhs_offset += z * lhs_stride_z;
1141
1142#endif // defined(REINTERPRET_INPUT_AS_3D)
1143
1144 // Initialize the accumulators
1145 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1146
1147 int i = 0;
1148 for(; i <= (K - K0); i += K0)
1149 {
1150 // Supported cases (M0, K0):
1151 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1152 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1153 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1154 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1155 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1156 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1157 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1158 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1159 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001160 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001161
1162 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001163 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001164
1165 // Accumulate
1166 ARM_DOT_K0XN0(K0, a0, b, c0);
1167#if M0 > 1
1168 ARM_DOT_K0XN0(K0, a1, b, c1);
1169#endif // M0 > 1
1170#if M0 > 2
1171 ARM_DOT_K0XN0(K0, a2, b, c2);
1172#endif // M0 > 2
1173#if M0 > 3
1174 ARM_DOT_K0XN0(K0, a3, b, c3);
1175#endif // M0 > 3
1176#if M0 > 4
1177 ARM_DOT_K0XN0(K0, a4, b, c4);
1178#endif // M0 > 4
1179#if M0 > 5
1180 ARM_DOT_K0XN0(K0, a5, b, c5);
1181#endif // M0 > 5
1182#if M0 > 6
1183 ARM_DOT_K0XN0(K0, a6, b, c6);
1184#endif // M0 > 6
1185#if M0 > 7
1186 ARM_DOT_K0XN0(K0, a7, b, c7);
1187#endif // M0 > 7
1188
1189 lhs_offset += K0 * sizeof(DATA_TYPE);
1190 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1191 }
1192
1193 // Left-over accumulations
1194 for(; i < K; ++i)
1195 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001196 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001197 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001198
1199 // Load values from RHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001200 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001201
1202 // Accumulate
1203 ARM_DOT_K0XN0(1, a0, b, c0);
1204#if M0 > 1
1205 ARM_DOT_K0XN0(1, a1, b, c1);
1206#endif // M0 > 1
1207#if M0 > 2
1208 ARM_DOT_K0XN0(1, a2, b, c2);
1209#endif // M0 > 2
1210#if M0 > 3
1211 ARM_DOT_K0XN0(1, a3, b, c3);
1212#endif // M0 > 3
1213#if M0 > 4
1214 ARM_DOT_K0XN0(1, a4, b, c4);
1215#endif // M0 > 4
1216#if M0 > 5
1217 ARM_DOT_K0XN0(1, a5, b, c5);
1218#endif // M0 > 5
1219#if M0 > 6
1220 ARM_DOT_K0XN0(1, a6, b, c6);
1221#endif // M0 > 6
1222#if M0 > 7
1223 ARM_DOT_K0XN0(1, a7, b, c7);
1224#endif // M0 > 7
1225
1226 lhs_offset += sizeof(DATA_TYPE);
1227 rhs_offset += sizeof(DATA_TYPE);
1228 }
1229
1230 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1231
1232 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1233
1234#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001235
1236 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001237 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001238
1239 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1240 // multiply dst_stride_z by DEPTH_GEMM3D
1241 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1242
1243#else // defined(REINTERPRET_OUTPUT_AS_3D)
1244
1245 // Add offset for batched GEMM
1246 dst_addr += z * dst_stride_z;
1247
1248#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1249
1250 // Multiply by the weight of matrix-matrix product and store the result
1251#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001252 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001253#endif // defined(ALPHA)
1254
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001255 // Add beta*bias
1256#if defined(BETA)
1257#if defined(BROADCAST_BIAS)
1258 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1259
1260 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1261
1262#ifndef UNIT_BETA
1263 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1264#endif // UNIT_BIAS
1265
1266 // c = c + bias[broadcasted]
1267 ADD_BLOCK_BROADCAST(M0, c, bias0);
1268
1269#else // defined(BROADCAST_BIAS)
1270 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1271 2) * bias_stride_z;
1272
1273 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1274
1275#ifndef UNIT_BETA
1276 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1277#endif // UNIT_BIAS
1278
1279 // c = c + bias
1280 ADD_BLOCK(M0, c, bias);
1281
1282#endif // defined(BROADCAST_BIAS)
1283#endif // defined(BETA)
1284
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001285#if defined(ACTIVATION_TYPE)
1286 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1287#endif // defined(ACTIVATION_TYPE)
1288
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001289 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001290 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001291
1292#undef RHS_BLOCK_SIZE
1293#undef RHS_OFFSET_X
1294#undef RHS_STEP_X
1295}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001296
/** Fused multiply-add accumulation: c = fma(a, b, c).
 *
 * @param[in]      a First multiplicand
 * @param[in]      b Second multiplicand
 * @param[in, out] c Accumulator. Must be a modifiable lvalue: it is read as the addend and overwritten with the result
 */
#define VFMA(a, b, c)           \
    ({                          \
        (c) = fma((a), (b), (c)); \
    })
1301
1302#if M0 == 1
1303#define LD_RHS_VFMA_M0xN0(i, a, c) \
1304 ({ \
1305 VEC_DATA_TYPE(DATA_TYPE, N0) \
1306 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1307 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1308 })
1309#elif M0 == 2 // M0 == 2
1310#define LD_RHS_VFMA_M0xN0(i, a, c) \
1311 ({ \
1312 VEC_DATA_TYPE(DATA_TYPE, N0) \
1313 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1314 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1315 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1316 })
1317#elif M0 == 3 // M0 == 3
1318#define LD_RHS_VFMA_M0xN0(i, a, c) \
1319 ({ \
1320 VEC_DATA_TYPE(DATA_TYPE, N0) \
1321 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1322 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1323 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1324 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1325 })
1326#elif M0 == 4 // M0 == 4
1327#define LD_RHS_VFMA_M0xN0(i, a, c) \
1328 ({ \
1329 VEC_DATA_TYPE(DATA_TYPE, N0) \
1330 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1331 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1332 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1333 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1334 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1335 })
1336#elif M0 == 5 // M0 == 5
1337#define LD_RHS_VFMA_M0xN0(i, a, c) \
1338 ({ \
1339 VEC_DATA_TYPE(DATA_TYPE, N0) \
1340 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1341 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1342 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1343 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1344 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1345 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1346 })
1347#elif M0 == 6 // M0 == 6
1348#define LD_RHS_VFMA_M0xN0(i, a, c) \
1349 ({ \
1350 VEC_DATA_TYPE(DATA_TYPE, N0) \
1351 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1352 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1353 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1354 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1355 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1356 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1357 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1358 })
1359#elif M0 == 7 // M0 == 7
1360#define LD_RHS_VFMA_M0xN0(i, a, c) \
1361 ({ \
1362 VEC_DATA_TYPE(DATA_TYPE, N0) \
1363 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1364 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1365 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1366 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1367 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1368 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1369 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1370 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1371 })
1372#elif M0 == 8 // M0 == 8
1373#define LD_RHS_VFMA_M0xN0(i, a, c) \
1374 ({ \
1375 VEC_DATA_TYPE(DATA_TYPE, N0) \
1376 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1377 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1378 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1379 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1380 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1381 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1382 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1383 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1384 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1385 })
1386#else // M0 not supported
1387#error "M0 not supported"
1388#endif // M0 not supported
1389
1390/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1391 * The LHS matrix is NOT reshaped
1392 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1393 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001394 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1395 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90).
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001396 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1397 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1398 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1399 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1400 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1401 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1402 * - N0 = 2, 3, 4, 8, 16
1403 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001404 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001405 *
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001406 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1407 * The activation function is performed after the bias addition
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001408 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1409 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1410 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1411 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1412 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1413 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1414 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001415 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1416 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1417 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1418 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1419 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1420 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1421 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1422 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1423 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1424 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1425 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1426 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001427 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1428 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001429 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001430 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001431 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1432 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1433 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1434 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1435 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1436 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1437 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1438 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1439 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1440 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001441 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001442 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1443 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1444 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001445 */
1446__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1447 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001448#if defined(BETA)
1449 IMAGE_DECLARATION(bias),
1450#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001451 IMAGE_DECLARATION(dst),
1452 uint lhs_stride_z,
1453 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001454#if defined(BETA)
1455 uint bias_stride_z,
1456#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001457 uint dst_stride_z
1458#if defined(REINTERPRET_INPUT_AS_3D)
1459 ,
1460 uint lhs_cross_plane_pad
1461#endif // REINTERPRET_INPUT_AS_3D
1462#if defined(REINTERPRET_OUTPUT_AS_3D)
1463 ,
1464 uint dst_cross_plane_pad
1465#endif // REINTERPRET_OUTPUT_AS_3D
1466 )
1467{
1468 // Block size
1469#define RHS_BLOCK_SIZE ((K0) * (N0))
1470
1471 // RHS offset and step X
1472#if defined(RHS_INTERLEAVE)
1473#define RHS_OFFSET_X (N0)
1474#define RHS_STEP_X ((N0) * (H0))
1475#define RHS_STEP_LOOP (1)
1476#else // defined(RHS_INTERLEAVE)
1477#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1478#define RHS_STEP_X (N0)
1479#define RHS_STEP_LOOP (H0)
1480#endif // defined(RHS_INTERLEAVE)
1481
1482 uint x = get_global_id(0);
1483 uint y = get_global_id(1);
1484 uint z = get_global_id(2);
1485
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001486#if defined(DUMMY_WORK_ITEMS)
1487 if((x * N0 >= N) || (y * M0 >= M))
1488 {
1489 return;
1490 }
1491#endif // defined(DUMMY_WORK_ITEMS)
1492
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001493 // Compute LHS matrix address
1494 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1495
1496 // Compute RHS matrix address
1497 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1498
1499#if defined(MATRIX_B_DEPTH)
1500 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1501 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1502#else // defined(MATRIX_B_DEPTH)
1503 rhs_offset += z * rhs_stride_z;
1504#endif // defined(MATRIX_B_DEPTH)
1505
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001506 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1507 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001508
1509#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001510
1511 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001512 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001513
1514 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1515 // multiply lhs_stride_z by DEPTH_GEMM3D
1516 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1517
1518#else // defined(REINTERPRET_INPUT_AS_3D)
1519
1520 // Add offset for batched GEMM
1521 lhs_offset += z * lhs_stride_z;
1522
1523#endif // defined(REINTERPRET_INPUT_AS_3D)
1524
1525 // Initialize the accumulators
1526 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1527
1528 int i = 0;
1529 for(; i <= (K - K0); i += K0)
1530 {
1531 // Supported cases (M0, K0):
1532 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1533 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1534 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1535 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1536 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1537 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1538 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1539 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1540 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001541 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001542
1543 LD_RHS_VFMA_M0xN0(0, a, c);
1544 LD_RHS_VFMA_M0xN0(1, a, c);
1545#if K0 > 2
1546 LD_RHS_VFMA_M0xN0(2, a, c);
1547#endif // K0 > 2
1548#if K0 > 3
1549 LD_RHS_VFMA_M0xN0(3, a, c);
1550#endif // K0 > 3
1551#if K0 > 4
1552 LD_RHS_VFMA_M0xN0(4, a, c);
1553 LD_RHS_VFMA_M0xN0(5, a, c);
1554 LD_RHS_VFMA_M0xN0(6, a, c);
1555 LD_RHS_VFMA_M0xN0(7, a, c);
1556#endif // K0 > 4
1557#if K0 > 8
1558 LD_RHS_VFMA_M0xN0(8, a, c);
1559 LD_RHS_VFMA_M0xN0(9, a, c);
1560 LD_RHS_VFMA_M0xN0(A, a, c);
1561 LD_RHS_VFMA_M0xN0(B, a, c);
1562 LD_RHS_VFMA_M0xN0(C, a, c);
1563 LD_RHS_VFMA_M0xN0(D, a, c);
1564 LD_RHS_VFMA_M0xN0(E, a, c);
1565 LD_RHS_VFMA_M0xN0(F, a, c);
1566#endif // K0 > 8
1567
1568 lhs_offset += K0 * sizeof(DATA_TYPE);
1569 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1570 }
1571
1572 // Left-over accumulations
1573 for(; i < K; ++i)
1574 {
1575 // Load values from LHS matrix
1576 VEC_DATA_TYPE(DATA_TYPE, 2)
1577 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1578#if M0 > 1
1579 VEC_DATA_TYPE(DATA_TYPE, 2)
1580 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1581#endif // M0 > 1
1582#if M0 > 2
1583 VEC_DATA_TYPE(DATA_TYPE, 2)
1584 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1585#endif // M0 > 2
1586#if M0 > 3
1587 VEC_DATA_TYPE(DATA_TYPE, 2)
1588 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1589#endif // M0 > 3
1590#if M0 > 4
1591 VEC_DATA_TYPE(DATA_TYPE, 2)
1592 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1593#endif // M0 > 4
1594#if M0 > 5
1595 VEC_DATA_TYPE(DATA_TYPE, 2)
1596 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1597#endif // M0 > 5
1598#if M0 > 6
1599 VEC_DATA_TYPE(DATA_TYPE, 2)
1600 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1601#endif // M0 > 6
1602#if M0 > 7
1603 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001604 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001605#endif // M0 > 7
1606
1607 LD_RHS_VFMA_M0xN0(0, a, c);
1608
1609 lhs_offset += sizeof(DATA_TYPE);
1610 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1611 }
1612
1613 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1614
1615 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1616
1617#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001618 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001619 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001620
1621 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1622 // multiply dst_stride_z by DEPTH_GEMM3D
1623 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1624
1625#else // defined(REINTERPRET_OUTPUT_AS_3D)
1626
1627 // Add offset for batched GEMM
1628 dst_addr += z * dst_stride_z;
1629
1630#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1631
1632 // Multiply by the weight of matrix-matrix product and store the result
1633#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001634 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001635#endif // defined(ALPHA)
1636
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001637 // Add beta*bias
1638#if defined(BETA)
1639#if defined(BROADCAST_BIAS)
1640 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1641
1642 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1643
1644#ifndef UNIT_BETA
1645 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1646#endif // UNIT_BIAS
1647
1648 // c = c + bias[broadcasted]
1649 ADD_BLOCK_BROADCAST(M0, c, bias0);
1650
1651#else // defined(BROADCAST_BIAS)
1652 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1653 2) * bias_stride_z;
1654
1655 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1656
1657#ifndef UNIT_BETA
1658 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1659#endif // UNIT_BIAS
1660
1661 // c = c + bias
1662 ADD_BLOCK(M0, c, bias);
1663
1664#endif // defined(BROADCAST_BIAS)
1665#endif // defined(BETA)
1666
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001667#if defined(ACTIVATION_TYPE)
1668 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1669#endif // defined(ACTIVATION_TYPE)
1670
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001671 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001672 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001673
1674#undef RHS_BLOCK_SIZE
1675#undef RHS_OFFSET_X
1676#undef RHS_STEP_X
1677}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001678#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001679
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001680#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001681
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001682#if K0 == 2
1683#define ARM_DOT_K0(a, b, c) \
1684 ({ \
1685 c = fma(a.s0, b.s0, c); \
1686 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001687 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001688#elif K0 == 3 // K0 == 3
1689#define ARM_DOT_K0(a, b, c) \
1690 ({ \
1691 c = fma(a.s0, b.s0, c); \
1692 c = fma(a.s1, b.s1, c); \
1693 c = fma(a.s2, b.s2, c); \
1694 })
1695#elif K0 == 4 // K0 == 4
1696#define ARM_DOT_K0(a, b, c) \
1697 ({ \
1698 c = fma(a.s0, b.s0, c); \
1699 c = fma(a.s1, b.s1, c); \
1700 c = fma(a.s2, b.s2, c); \
1701 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001702 })
1703#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001704#define ARM_DOT_K0(a, b, c) \
1705 ({ \
1706 c = fma(a.s0, b.s0, c); \
1707 c = fma(a.s1, b.s1, c); \
1708 c = fma(a.s2, b.s2, c); \
1709 c = fma(a.s3, b.s3, c); \
1710 c = fma(a.s4, b.s4, c); \
1711 c = fma(a.s5, b.s5, c); \
1712 c = fma(a.s6, b.s6, c); \
1713 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001714 })
1715#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001716#define ARM_DOT_K0(a, b, c) \
1717 ({ \
1718 c = fma(a.s0, b.s0, c); \
1719 c = fma(a.s1, b.s1, c); \
1720 c = fma(a.s2, b.s2, c); \
1721 c = fma(a.s3, b.s3, c); \
1722 c = fma(a.s4, b.s4, c); \
1723 c = fma(a.s5, b.s5, c); \
1724 c = fma(a.s6, b.s6, c); \
1725 c = fma(a.s7, b.s7, c); \
1726 c = fma(a.s8, b.s8, c); \
1727 c = fma(a.s9, b.s9, c); \
1728 c = fma(a.sA, b.sA, c); \
1729 c = fma(a.sB, b.sB, c); \
1730 c = fma(a.sC, b.sC, c); \
1731 c = fma(a.sD, b.sD, c); \
1732 c = fma(a.sE, b.sE, c); \
1733 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001734 })
1735#else // K0 not supported
1736#error "K0 value not supported"
1737#endif // K0 conditions
1738
1739#if N0 == 2
1740#define ARM_DOT_K0XN0(a, b, c) \
1741 ({ \
1742 ARM_DOT_K0((a), (b##0), (c.s0)); \
1743 ARM_DOT_K0((a), (b##1), (c.s1)); \
1744 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001745#elif N0 == 3 // N0 == 3
1746#define ARM_DOT_K0XN0(a, b, c) \
1747 ({ \
1748 ARM_DOT_K0((a), (b##0), (c.s0)); \
1749 ARM_DOT_K0((a), (b##1), (c.s1)); \
1750 ARM_DOT_K0((a), (b##2), (c.s2)); \
1751 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001752#elif N0 == 4 // N0 == 4
1753#define ARM_DOT_K0XN0(a, b, c) \
1754 ({ \
1755 ARM_DOT_K0((a), (b##0), (c.s0)); \
1756 ARM_DOT_K0((a), (b##1), (c.s1)); \
1757 ARM_DOT_K0((a), (b##2), (c.s2)); \
1758 ARM_DOT_K0((a), (b##3), (c.s3)); \
1759 })
1760#elif N0 == 8 // N0 == 8
1761#define ARM_DOT_K0XN0(a, b, c) \
1762 ({ \
1763 ARM_DOT_K0((a), (b##0), (c.s0)); \
1764 ARM_DOT_K0((a), (b##1), (c.s1)); \
1765 ARM_DOT_K0((a), (b##2), (c.s2)); \
1766 ARM_DOT_K0((a), (b##3), (c.s3)); \
1767 ARM_DOT_K0((a), (b##4), (c.s4)); \
1768 ARM_DOT_K0((a), (b##5), (c.s5)); \
1769 ARM_DOT_K0((a), (b##6), (c.s6)); \
1770 ARM_DOT_K0((a), (b##7), (c.s7)); \
1771 })
1772#elif N0 == 16 // N0 == 16
1773#define ARM_DOT_K0XN0(a, b, c) \
1774 ({ \
1775 ARM_DOT_K0((a), (b##0), (c.s0)); \
1776 ARM_DOT_K0((a), (b##1), (c.s1)); \
1777 ARM_DOT_K0((a), (b##2), (c.s2)); \
1778 ARM_DOT_K0((a), (b##3), (c.s3)); \
1779 ARM_DOT_K0((a), (b##4), (c.s4)); \
1780 ARM_DOT_K0((a), (b##5), (c.s5)); \
1781 ARM_DOT_K0((a), (b##6), (c.s6)); \
1782 ARM_DOT_K0((a), (b##7), (c.s7)); \
1783 ARM_DOT_K0((a), (b##8), (c.s8)); \
1784 ARM_DOT_K0((a), (b##9), (c.s9)); \
1785 ARM_DOT_K0((a), (b##A), (c.sA)); \
1786 ARM_DOT_K0((a), (b##B), (c.sB)); \
1787 ARM_DOT_K0((a), (b##C), (c.sC)); \
1788 ARM_DOT_K0((a), (b##D), (c.sD)); \
1789 ARM_DOT_K0((a), (b##E), (c.sE)); \
1790 ARM_DOT_K0((a), (b##F), (c.sF)); \
1791 })
1792#else // N0 not supported
1793#error "N0 value not supported"
1794#endif // N0 conditions
1795
1796/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1797 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1798 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1799 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001800 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1801 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001802 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1803 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1804 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1807 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001808 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001809 * - N0 = 2, 3, 4, 8, 16
1810 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001811 * - V0 >= 1
1812 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001813 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1815 * The activation function is performed after the bias addition
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001816 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1817 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1818 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1819 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1820 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1821 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001822 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1823 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1824 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1825 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1826 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1827 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1828 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1829 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1830 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1831 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1832 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1833 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1834 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1835 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1836 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1837 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1838 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1839 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1840 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1841 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1842 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1843 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1844 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1845 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1846 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1847 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1848 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1849 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1850 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1851 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001852 */
1853__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1854 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001855#if defined(BETA)
1856 IMAGE_DECLARATION(bias),
1857#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001858 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001859 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001860 uint lhs_stride_z,
1861 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001862#if defined(BETA)
1863 uint bias_stride_z,
1864#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001865 uint dst_stride_z
1866#if defined(REINTERPRET_OUTPUT_AS_3D)
1867 ,
1868 uint dst_cross_plane_pad
1869#endif // REINTERPRET_OUTPUT_AS_3D
1870 )
1871{
1872 // Block size
1873#define LHS_BLOCK_SIZE ((K0) * (M0))
1874
1875#if defined(LHS_INTERLEAVE)
1876#define LHS_OFFSET_X (K0)
1877#define LHS_STEP_X ((K0) * (V0))
1878#define LHS_STEP_LOOP (1)
1879#else // defined(INTERLEAVE)
1880#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1881#define LHS_STEP_X (K0)
1882#define LHS_STEP_LOOP (V0)
1883#endif // defined(INTERLEAVE)
1884
1885 // Block size
1886#define RHS_BLOCK_SIZE ((K0) * (N0))
1887
1888 // RHS offset and step X
1889#if defined(RHS_INTERLEAVE)
1890#define RHS_OFFSET_X (K0)
1891#define RHS_STEP_X ((K0) * (H0))
1892#define RHS_STEP_LOOP (1)
1893#else // defined(RHS_INTERLEAVE)
1894#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1895#define RHS_STEP_X (K0)
1896#define RHS_STEP_LOOP (H0)
1897#endif // defined(RHS_INTERLEAVE)
1898
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001899#if defined(DUMMY_WORK_ITEMS)
1900 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1901 {
1902 return;
1903 }
1904#endif // defined(DUMMY_WORK_ITEMS)
1905
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001906 // Compute LHS matrix address
1907 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1908 (get_global_id(2) * lhs_stride_z);
1909
1910 // Compute RHS matrix address
1911 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1912
1913#if defined(MATRIX_B_DEPTH)
1914 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1915 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1916#else // defined(MATRIX_B_DEPTH)
1917 rhs_addr += get_global_id(2) * rhs_stride_z;
1918#endif // defined(MATRIX_B_DEPTH)
1919
1920 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001921 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001922
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001923 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1924 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001925
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001926 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001927 {
1928 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001929 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1930 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1931 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1932 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1933 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1934 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1935 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1936 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001937 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001938 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001939
1940 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001941 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001942
1943 // Accumulate
1944 ARM_DOT_K0XN0(a0, b, c0);
1945#if M0 > 1
1946 ARM_DOT_K0XN0(a1, b, c1);
1947#endif // M0 > 1
1948#if M0 > 2
1949 ARM_DOT_K0XN0(a2, b, c2);
1950#endif // M0 > 2
1951#if M0 > 3
1952 ARM_DOT_K0XN0(a3, b, c3);
1953#endif // M0 > 3
1954#if M0 > 4
1955 ARM_DOT_K0XN0(a4, b, c4);
1956#endif // M0 > 4
1957#if M0 > 5
1958 ARM_DOT_K0XN0(a5, b, c5);
1959#endif // M0 > 5
1960#if M0 > 6
1961 ARM_DOT_K0XN0(a6, b, c6);
1962#endif // M0 > 6
1963#if M0 > 7
1964 ARM_DOT_K0XN0(a7, b, c7);
1965#endif // M0 > 7
1966
1967 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1968 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1969 }
1970
1971 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1972
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001973 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001974
1975#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001976
1977 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001978 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001979 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1980 // multiply dst_stride_z by DEPTH_GEMM3D
1981 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1982
1983#else // defined(REINTERPRET_OUTPUT_AS_3D)
1984
1985 // Add offset for batched GEMM
1986 dst_addr += get_global_id(2) * dst_stride_z;
1987
1988#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1989
1990 // Multiply by the weight of matrix-matrix product and store the result
1991#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001992 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001993#endif // defined(ALPHA)
1994
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001995 // Add beta*bias
1996#if defined(BETA)
1997#if defined(BROADCAST_BIAS)
1998 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1999
2000 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2001
2002#ifndef UNIT_BETA
2003 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2004#endif // UNIT_BIAS
2005
2006 // c = c + bias[broadcasted]
2007 ADD_BLOCK_BROADCAST(M0, c, bias0);
2008
2009#else // defined(BROADCAST_BIAS)
2010 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2011 2) * bias_stride_z;
2012
2013 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2014
2015#ifndef UNIT_BETA
2016 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2017#endif // UNIT_BIAS
2018
2019 // c = c + bias
2020 ADD_BLOCK(M0, c, bias);
2021
2022#endif // defined(BROADCAST_BIAS)
2023#endif // defined(BETA)
2024
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002025#if defined(ACTIVATION_TYPE)
2026 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2027#endif // defined(ACTIVATION_TYPE)
2028
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002029 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01002030 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002031
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002032#undef LHS_BLOCK_SIZE
2033#undef LHS_OFFSET_X
2034#undef LHS_STEP_X
2035#undef RHS_BLOCK_SIZE
2036#undef RHS_OFFSET_X
2037#undef RHS_STEP_X
2038}
giuros01b3204e72019-04-01 13:50:22 +01002039
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002040#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2041
giuros01b3204e72019-04-01 13:50:22 +01002042#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2043
2044#define VFMA(a, b, c) \
2045 ({ \
2046 c = fma(a, b, c); \
2047 })
2048
2049#if M0 == 1
2050#define RHS_VFMA_M0xN0(i, a, b, c) \
2051 ({ \
2052 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2053 })
2054#elif M0 == 2 // M0 == 2
2055#define RHS_VFMA_M0xN0(i, a, b, c) \
2056 ({ \
2057 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2058 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2059 })
2060#elif M0 == 3 // M0 == 3
2061#define RHS_VFMA_M0xN0(i, a, b, c) \
2062 ({ \
2063 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2064 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2065 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2066 })
2067#elif M0 == 4 // M0 == 4
2068#define RHS_VFMA_M0xN0(i, a, b, c) \
2069 ({ \
2070 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2071 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2072 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2073 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2074 })
2075#elif M0 == 5 // M0 == 5
2076#define RHS_VFMA_M0xN0(i, a, b, c) \
2077 ({ \
2078 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2079 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2080 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2081 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2082 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2083 })
2084#elif M0 == 6 // M0 == 6
2085#define RHS_VFMA_M0xN0(i, a, b, c) \
2086 ({ \
2087 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2088 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2089 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2090 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2091 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2092 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2093 })
2094#elif M0 == 7 // M0 == 7
2095#define RHS_VFMA_M0xN0(i, a, b, c) \
2096 ({ \
2097 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2098 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2099 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2100 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2101 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2102 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2103 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2104 })
2105#elif M0 == 8 // M0 == 8
2106#define RHS_VFMA_M0xN0(i, a, b, c) \
2107 ({ \
2108 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2109 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2110 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2111 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2112 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2113 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2114 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2115 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
2116 })
2117#else // M0 not supported
2118#error "M0 not supported"
2119#endif // M0 not supported
2120
2121/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2122 * The LHS matrix is NOT reshaped
2123 * The RHS matrix is NOT reshaped
2124 *
2125 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2126 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
2127 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
2128 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
2129 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2)
2130 * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
2131 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2132 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2133 * - N0 = 2, 3, 4, 8, 16
2134 * - K0 = 2, 3, 4, 8, 16
2135 *
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002136 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (i.e. -DACTIVATION_TYPE=RELU), A, B variables required by some activation functions and should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2137 * The activation function is performed after the bias addition
giuros01b3204e72019-04-01 13:50:22 +01002138 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2139 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2140 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2141 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2142 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2143 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2144 *
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002145 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
2146 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
2147 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
2148 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
2149 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2150 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
2151 * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
2152 * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
2153 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
2154 * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
2155 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
2156 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
2157 * @param[in] bias_ptr (Optional)Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
2158 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2159 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2160 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2161 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2162 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2163 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2164 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2165 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2166 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2167 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2168 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2169 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2170 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
2171 * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
2172 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2173 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2174 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2175 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
giuros01b3204e72019-04-01 13:50:22 +01002176 */
2177__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2178 IMAGE_DECLARATION(rhs),
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002179#if defined(BETA)
2180 IMAGE_DECLARATION(bias),
2181#endif // defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002182 IMAGE_DECLARATION(dst),
2183 uint lhs_stride_z,
2184 uint rhs_stride_z,
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002185#if defined(BETA)
2186 uint bias_stride_z,
2187#endif //defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01002188 uint dst_stride_z
2189#if defined(REINTERPRET_INPUT_AS_3D)
2190 ,
2191 uint lhs_cross_plane_pad
2192#endif // REINTERPRET_INPUT_AS_3D
2193#if defined(REINTERPRET_OUTPUT_AS_3D)
2194 ,
2195 uint dst_cross_plane_pad
2196#endif // REINTERPRET_OUTPUT_AS_3D
2197 )
2198{
2199 // Block size
2200#define RHS_BLOCK_SIZE ((K0) * (N0))
2201
2202 // RHS offset and step X
2203#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2204
2205 uint x = get_global_id(0);
2206 uint y = get_global_id(1);
2207 uint z = get_global_id(2);
2208
2209#if defined(DUMMY_WORK_ITEMS)
2210 if((x * N0 >= N) || (y * M0 >= M))
2211 {
2212 return;
2213 }
2214#endif // defined(DUMMY_WORK_ITEMS)
2215
2216 // Compute LHS matrix address
2217 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2218
2219 // Compute RHS matrix address
2220 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2221
2222#if defined(MATRIX_B_DEPTH)
2223 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2224 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2225#else // defined(MATRIX_B_DEPTH)
2226 rhs_offset += z * rhs_stride_z;
2227#endif // defined(MATRIX_B_DEPTH)
2228
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002229 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
2230 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
giuros01b3204e72019-04-01 13:50:22 +01002231
2232#if defined(REINTERPRET_INPUT_AS_3D)
2233 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2234 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2235
2236 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2237 // multiply lhs_stride_z by DEPTH_GEMM3D
2238 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2239
2240#else // defined(REINTERPRET_INPUT_AS_3D)
2241
2242 // Add offset for batched GEMM
2243 lhs_offset += z * lhs_stride_z;
2244
2245#endif // defined(REINTERPRET_INPUT_AS_3D)
2246
2247 // Initialize the accumulators
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002248 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
giuros01b3204e72019-04-01 13:50:22 +01002249
2250 int i = 0;
2251 for(; i <= (K - K0); i += K0)
2252 {
2253 // Supported cases (M0, K0):
2254 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2255 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2256 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2257 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2258 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2259 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2260 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2261 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2262 // Load values from LHS matrix
2263 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2264
2265 // Load values from RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002266 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
giuros01b3204e72019-04-01 13:50:22 +01002267
2268 RHS_VFMA_M0xN0(0, a, b0, c);
2269 RHS_VFMA_M0xN0(1, a, b1, c);
2270#if K0 > 2
2271 RHS_VFMA_M0xN0(2, a, b2, c);
2272#endif // K0 > 2
2273#if K0 > 3
2274 RHS_VFMA_M0xN0(3, a, b3, c);
2275#endif // K0 > 3
2276#if K0 > 4
2277 RHS_VFMA_M0xN0(4, a, b4, c);
2278 RHS_VFMA_M0xN0(5, a, b5, c);
2279 RHS_VFMA_M0xN0(6, a, b6, c);
2280 RHS_VFMA_M0xN0(7, a, b7, c);
2281#endif // K0 > 4
2282#if K0 > 8
2283 RHS_VFMA_M0xN0(8, a, b8, c);
2284 RHS_VFMA_M0xN0(9, a, b9, c);
2285 RHS_VFMA_M0xN0(A, a, b10, c);
2286 RHS_VFMA_M0xN0(B, a, b11, c);
2287 RHS_VFMA_M0xN0(C, a, b12, c);
2288 RHS_VFMA_M0xN0(D, a, b13, c);
2289 RHS_VFMA_M0xN0(E, a, b14, c);
2290 RHS_VFMA_M0xN0(F, a, b15, c);
2291#endif // K0 > 8
2292
2293 lhs_offset += K0 * sizeof(DATA_TYPE);
2294 rhs_offset += K0 * rhs_stride_y;
2295 }
2296
2297 // Left-over accumulations
2298 for(; i < K; ++i)
2299 {
2300 // Load values from LHS matrix
2301 VEC_DATA_TYPE(DATA_TYPE, 2)
2302 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2303#if M0 > 1
2304 VEC_DATA_TYPE(DATA_TYPE, 2)
2305 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2306#endif // M0 > 1
2307#if M0 > 2
2308 VEC_DATA_TYPE(DATA_TYPE, 2)
2309 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2310#endif // M0 > 2
2311#if M0 > 3
2312 VEC_DATA_TYPE(DATA_TYPE, 2)
2313 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2314#endif // M0 > 3
2315#if M0 > 4
2316 VEC_DATA_TYPE(DATA_TYPE, 2)
2317 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2318#endif // M0 > 4
2319#if M0 > 5
2320 VEC_DATA_TYPE(DATA_TYPE, 2)
2321 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2322#endif // M0 > 5
2323#if M0 > 6
2324 VEC_DATA_TYPE(DATA_TYPE, 2)
2325 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2326#endif // M0 > 6
2327#if M0 > 7
2328 VEC_DATA_TYPE(DATA_TYPE, 2)
2329 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2330#endif // M0 > 7
2331
2332 VEC_DATA_TYPE(DATA_TYPE, N0)
2333 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2334 RHS_VFMA_M0xN0(0, a, b, c);
2335
2336 lhs_offset += sizeof(DATA_TYPE);
2337 rhs_offset += rhs_stride_y;
2338 }
2339
2340 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2341
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002342 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
giuros01b3204e72019-04-01 13:50:22 +01002343
2344#if defined(REINTERPRET_OUTPUT_AS_3D)
2345 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2346 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2347
2348 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2349 // multiply dst_stride_z by DEPTH_GEMM3D
2350 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2351
2352#else // defined(REINTERPRET_OUTPUT_AS_3D)
2353
2354 // Add offset for batched GEMM
2355 dst_addr += z * dst_stride_z;
2356
2357#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2358
2359 // Multiply by the weight of matrix-matrix product and store the result
giuros01b3204e72019-04-01 13:50:22 +01002360#if defined(ALPHA)
2361 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2362#endif // defined(ALPHA)
2363
Gian Marco Iodice944170e2019-06-24 14:40:30 +01002364 // Add beta*bias
2365#if defined(BETA)
2366#if defined(BROADCAST_BIAS)
2367 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2368
2369 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2370
2371#ifndef UNIT_BETA
2372 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2373#endif // UNIT_BIAS
2374
2375 // c = c + bias[broadcasted]
2376 ADD_BLOCK_BROADCAST(M0, c, bias0);
2377
2378#else // defined(BROADCAST_BIAS)
2379 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2380 2) * bias_stride_z;
2381
2382 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2383
2384#ifndef UNIT_BETA
2385 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2386#endif // UNIT_BIAS
2387
2388 // c = c + bias
2389 ADD_BLOCK(M0, c, bias);
2390
2391#endif // defined(BROADCAST_BIAS)
2392#endif // defined(BETA)
2393
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002394#if defined(ACTIVATION_TYPE)
2395 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2396#endif // defined(ACTIVATION_TYPE)
2397
giuros01b3204e72019-04-01 13:50:22 +01002398 // Store output block
2399 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2400
2401#undef RHS_BLOCK_SIZE
2402#undef RHS_OFFSET_X
2403#undef RHS_STEP_X
2404}
2405#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2406
Gian Marco36a0a462018-01-12 10:21:40 +00002407#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002408/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002409 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002410 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002411 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2412 *
Gian Marco19835e52018-01-30 13:35:54 +00002413 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2414 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2415 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002416 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2417 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002418 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002419 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2420 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2421 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2422 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2423 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2424 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002425 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2426 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002427 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2428 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2429 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2430 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2431 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2432 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002433 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002434 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2435 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2436 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2437 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2438 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002439 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2440 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2441 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2442 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002443 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002444 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002445 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002446 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002447 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002448 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002449 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2450 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2451 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002452 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002453 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002454__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2455 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002456#if defined(ADD_VEC_C)
2457 VECTOR_DECLARATION(src2),
2458#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002459 IMAGE_DECLARATION(dst),
2460 uint src0_stride_z,
2461 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002462 uint dst_stride_z
2463#if defined(REINTERPRET_OUTPUT_AS_3D)
2464 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002465 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002466#endif // REINTERPRET_OUTPUT_AS_3D
2467 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002468{
Gian Marco36a0a462018-01-12 10:21:40 +00002469 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2470 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002471 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002472
Gian Marco36a0a462018-01-12 10:21:40 +00002473 // Offset
2474 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2475 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002476
Gian Marco36a0a462018-01-12 10:21:40 +00002477 // src_addr_a = address of matrix A
2478 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002479 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2480 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2481
2482#if defined(MATRIX_B_DEPTH)
2483 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2484 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2485#else // defined(MATRIX_B_DEPTH)
2486 src1_addr_in_bytes += z * src1_stride_z;
2487#endif // defined(MATRIX_B_DEPTH)
2488
2489 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2490 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002491
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002492 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002493 __global float *src_end_addr_b = src_addr_b + COLS_B;
2494
2495 src_addr_a += offset_row_a;
2496 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002497
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002498 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002499 float4 c00 = 0.0f;
2500 float4 c10 = 0.0f;
2501 float4 c20 = 0.0f;
2502 float4 c30 = 0.0f;
2503
Gian Marco36a0a462018-01-12 10:21:40 +00002504 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002505 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002506 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002507 float4 a0 = vload4(0, src_addr_a);
2508 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002509
2510 c00 += (float4)a0.s0 * b0;
2511 c10 += (float4)a0.s1 * b0;
2512 c20 += (float4)a0.s2 * b0;
2513 c30 += (float4)a0.s3 * b0;
2514
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002515 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002516 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2517 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002518
2519 c00 += (float4)a0.s0 * b0;
2520 c10 += (float4)a0.s1 * b0;
2521 c20 += (float4)a0.s2 * b0;
2522 c30 += (float4)a0.s3 * b0;
2523 }
2524
Gian Marco36a0a462018-01-12 10:21:40 +00002525 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002526 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002527 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002528 float4 a0 = vload4(0, src_addr_a);
2529 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002530
2531 c00 += (float4)a0.s0 * b0;
2532 c10 += (float4)a0.s1 * b0;
2533 c20 += (float4)a0.s2 * b0;
2534 c30 += (float4)a0.s3 * b0;
2535 }
2536
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002537 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002538 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2539
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002540#if defined(ALPHA)
2541 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002542 c00 = c00 * (float4)ALPHA;
2543 c10 = c10 * (float4)ALPHA;
2544 c20 = c20 * (float4)ALPHA;
2545 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002546#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002547
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002548#if defined(ADD_VEC_C)
2549 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2550 float4 c0 = vload4(0, src2_addr);
2551
2552 c00 += c0;
2553 c10 += c0;
2554 c20 += c0;
2555 c30 += c0;
2556#endif /* defined(ADD_VEC_C) */
2557
Gian Marcoae2af742018-02-15 12:35:44 +00002558 // Compute dst address
2559 __global uchar *dst_addr = offset(&dst, 0, 0);
2560
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002561#if defined(REINTERPRET_OUTPUT_AS_3D)
2562 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002563 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002564 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002565 // | |
2566 // | plane0 |
2567 // | |
2568 // |__________________|
2569 // |******************|
2570 // | cross_plane_pad |
2571 // |******************|
2572 // | |
2573 // | plane1 |
2574 // | |
2575 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002576
2577 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2578 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2579 zout = min(DEPTH_GEMM3D - 1, zout);
2580
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002581 // Add offset due to the cross plane paddings
2582 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002583
2584 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2585 // multiply dst_stride_z by DEPTH_GEMM3D
2586 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2587
2588 // Store 4x4 block
2589 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2590 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2591 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2592 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2593
2594#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002595 // Add offset for batched GEMM
2596 dst_addr += z * dst_stride_z;
2597
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002598 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002599 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2600 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2601 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2602 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002603#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002604}
2605
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002606/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002607 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
2608 *
2609 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002610 *
Gian Marco19835e52018-01-30 13:35:54 +00002611 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2612 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2613 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002614 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2615 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2616 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002617 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002618 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2619 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2620 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2621 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2622 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2623 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002624 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2625 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002626 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2627 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2628 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2629 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2630 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2631 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002632 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002633 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2634 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2635 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2636 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2637 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002638 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2639 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2640 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2641 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002642 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002643 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002644 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002645 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002646 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002647 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002648 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2649 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2650 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002651 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002652 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002653__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
2654 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002655#if defined(ADD_VEC_C)
2656 VECTOR_DECLARATION(src2),
2657#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002658 IMAGE_DECLARATION(dst),
2659 uint src0_stride_z,
2660 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002661 uint dst_stride_z
2662#if defined(REINTERPRET_OUTPUT_AS_3D)
2663 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002664 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002665#endif // REINTERPRET_OUTPUT_AS_3D
2666 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002667{
Gian Marco36a0a462018-01-12 10:21:40 +00002668 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2669 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002670 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00002671
2672 // Offset
2673 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2674 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
2675
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002676 // src_addr_a = address of matrix A
2677 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002678 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2679 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2680
2681#if defined(MATRIX_B_DEPTH)
2682 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2683 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2684#else // defined(MATRIX_B_DEPTH)
2685 src1_addr_in_bytes += z * src1_stride_z;
2686#endif // defined(MATRIX_B_DEPTH)
2687
2688 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2689 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002690
Gian Marco36a0a462018-01-12 10:21:40 +00002691 src_addr_a += offset_row_a;
2692 src_addr_b += offset_row_b;
2693
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002694 // Reset accumulators
2695 float c00 = 0.0f;
2696 float c01 = 0.0f;
2697 float c02 = 0.0f;
2698 float c03 = 0.0f;
2699 float c10 = 0.0f;
2700 float c11 = 0.0f;
2701 float c12 = 0.0f;
2702 float c13 = 0.0f;
2703 float c20 = 0.0f;
2704 float c21 = 0.0f;
2705 float c22 = 0.0f;
2706 float c23 = 0.0f;
2707 float c30 = 0.0f;
2708 float c31 = 0.0f;
2709 float c32 = 0.0f;
2710 float c33 = 0.0f;
2711
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002712#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
2713
2714 int i = 0;
2715 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002716 {
2717 // Load values from matrix A (interleaved) and matrix B (transposed)
2718 float4 a0 = vload4(0, src_addr_a);
2719 float4 b0 = vload4(0, src_addr_b);
2720
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002721 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2722 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002723
2724 c00 = fma(a0.s0, b0.s0, c00);
2725 c01 = fma(a0.s0, b0.s1, c01);
2726 c02 = fma(a0.s0, b0.s2, c02);
2727 c03 = fma(a0.s0, b0.s3, c03);
2728
2729 c10 = fma(a0.s1, b0.s0, c10);
2730 c11 = fma(a0.s1, b0.s1, c11);
2731 c12 = fma(a0.s1, b0.s2, c12);
2732 c13 = fma(a0.s1, b0.s3, c13);
2733
2734 c20 = fma(a0.s2, b0.s0, c20);
2735 c21 = fma(a0.s2, b0.s1, c21);
2736 c22 = fma(a0.s2, b0.s2, c22);
2737 c23 = fma(a0.s2, b0.s3, c23);
2738
2739 c30 = fma(a0.s3, b0.s0, c30);
2740 c31 = fma(a0.s3, b0.s1, c31);
2741 c32 = fma(a0.s3, b0.s2, c32);
2742 c33 = fma(a0.s3, b0.s3, c33);
2743
2744 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002745 a0 = vload4(0, src_addr_a);
2746 b0 = vload4(0, src_addr_b);
2747
2748 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2749 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002750
2751 c00 = fma(a0.s0, b0.s0, c00);
2752 c01 = fma(a0.s0, b0.s1, c01);
2753 c02 = fma(a0.s0, b0.s2, c02);
2754 c03 = fma(a0.s0, b0.s3, c03);
2755
2756 c10 = fma(a0.s1, b0.s0, c10);
2757 c11 = fma(a0.s1, b0.s1, c11);
2758 c12 = fma(a0.s1, b0.s2, c12);
2759 c13 = fma(a0.s1, b0.s3, c13);
2760
2761 c20 = fma(a0.s2, b0.s0, c20);
2762 c21 = fma(a0.s2, b0.s1, c21);
2763 c22 = fma(a0.s2, b0.s2, c22);
2764 c23 = fma(a0.s2, b0.s3, c23);
2765
2766 c30 = fma(a0.s3, b0.s0, c30);
2767 c31 = fma(a0.s3, b0.s1, c31);
2768 c32 = fma(a0.s3, b0.s2, c32);
2769 c33 = fma(a0.s3, b0.s3, c33);
2770
2771 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002772 a0 = vload4(0, src_addr_a);
2773 b0 = vload4(0, src_addr_b);
2774
2775 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2776 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2777
2778 c00 = fma(a0.s0, b0.s0, c00);
2779 c01 = fma(a0.s0, b0.s1, c01);
2780 c02 = fma(a0.s0, b0.s2, c02);
2781 c03 = fma(a0.s0, b0.s3, c03);
2782
2783 c10 = fma(a0.s1, b0.s0, c10);
2784 c11 = fma(a0.s1, b0.s1, c11);
2785 c12 = fma(a0.s1, b0.s2, c12);
2786 c13 = fma(a0.s1, b0.s3, c13);
2787
2788 c20 = fma(a0.s2, b0.s0, c20);
2789 c21 = fma(a0.s2, b0.s1, c21);
2790 c22 = fma(a0.s2, b0.s2, c22);
2791 c23 = fma(a0.s2, b0.s3, c23);
2792
2793 c30 = fma(a0.s3, b0.s0, c30);
2794 c31 = fma(a0.s3, b0.s1, c31);
2795 c32 = fma(a0.s3, b0.s2, c32);
2796 c33 = fma(a0.s3, b0.s3, c33);
2797
2798 // Load values from matrix A (interleaved) and matrix B (transposed)
2799 a0 = vload4(0, src_addr_a);
2800 b0 = vload4(0, src_addr_b);
2801
2802 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2803 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002804
2805 c00 = fma(a0.s0, b0.s0, c00);
2806 c01 = fma(a0.s0, b0.s1, c01);
2807 c02 = fma(a0.s0, b0.s2, c02);
2808 c03 = fma(a0.s0, b0.s3, c03);
2809
2810 c10 = fma(a0.s1, b0.s0, c10);
2811 c11 = fma(a0.s1, b0.s1, c11);
2812 c12 = fma(a0.s1, b0.s2, c12);
2813 c13 = fma(a0.s1, b0.s3, c13);
2814
2815 c20 = fma(a0.s2, b0.s0, c20);
2816 c21 = fma(a0.s2, b0.s1, c21);
2817 c22 = fma(a0.s2, b0.s2, c22);
2818 c23 = fma(a0.s2, b0.s3, c23);
2819
2820 c30 = fma(a0.s3, b0.s0, c30);
2821 c31 = fma(a0.s3, b0.s1, c31);
2822 c32 = fma(a0.s3, b0.s2, c32);
2823 c33 = fma(a0.s3, b0.s3, c33);
2824 }
2825
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002826 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002827 {
2828 // Load values from matrix A (interleaved) and matrix B (transposed)
2829 float4 a0 = vload4(0, src_addr_a);
2830 float4 b0 = vload4(0, src_addr_b);
2831
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002832 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2833 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2834
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002835 c00 = fma(a0.s0, b0.s0, c00);
2836 c01 = fma(a0.s0, b0.s1, c01);
2837 c02 = fma(a0.s0, b0.s2, c02);
2838 c03 = fma(a0.s0, b0.s3, c03);
2839
2840 c10 = fma(a0.s1, b0.s0, c10);
2841 c11 = fma(a0.s1, b0.s1, c11);
2842 c12 = fma(a0.s1, b0.s2, c12);
2843 c13 = fma(a0.s1, b0.s3, c13);
2844
2845 c20 = fma(a0.s2, b0.s0, c20);
2846 c21 = fma(a0.s2, b0.s1, c21);
2847 c22 = fma(a0.s2, b0.s2, c22);
2848 c23 = fma(a0.s2, b0.s3, c23);
2849
2850 c30 = fma(a0.s3, b0.s0, c30);
2851 c31 = fma(a0.s3, b0.s1, c31);
2852 c32 = fma(a0.s3, b0.s2, c32);
2853 c33 = fma(a0.s3, b0.s3, c33);
2854 }
2855
2856 // Compute destination address
2857 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2858
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002859#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002860 // Multiply by the weight of matrix product
2861 c00 = c00 * ALPHA;
2862 c01 = c01 * ALPHA;
2863 c02 = c02 * ALPHA;
2864 c03 = c03 * ALPHA;
2865 c10 = c10 * ALPHA;
2866 c11 = c11 * ALPHA;
2867 c12 = c12 * ALPHA;
2868 c13 = c13 * ALPHA;
2869 c20 = c20 * ALPHA;
2870 c21 = c21 * ALPHA;
2871 c22 = c22 * ALPHA;
2872 c23 = c23 * ALPHA;
2873 c30 = c30 * ALPHA;
2874 c31 = c31 * ALPHA;
2875 c32 = c32 * ALPHA;
2876 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002877#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002878
Gian Marcoae2af742018-02-15 12:35:44 +00002879 // Compute dst address
2880 __global uchar *dst_addr = offset(&dst, 0, 0);
2881
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002882#if defined(ADD_VEC_C)
2883 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2884 float4 c0 = vload4(0, src2_addr);
2885
2886 c00 += c0.s0;
2887 c01 += c0.s1;
2888 c02 += c0.s2;
2889 c03 += c0.s3;
2890 c10 += c0.s0;
2891 c11 += c0.s1;
2892 c12 += c0.s2;
2893 c13 += c0.s3;
2894 c20 += c0.s0;
2895 c21 += c0.s1;
2896 c22 += c0.s2;
2897 c23 += c0.s3;
2898 c30 += c0.s0;
2899 c31 += c0.s1;
2900 c32 += c0.s2;
2901 c33 += c0.s3;
2902#endif /* defined(ADD_VEC_C) */
2903
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002904#if defined(REINTERPRET_OUTPUT_AS_3D)
2905 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002906 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002907 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002908 // | |
2909 // | plane0 |
2910 // | |
2911 // |__________________|
2912 // |******************|
2913 // | cross_plane_pad |
2914 // |******************|
2915 // | |
2916 // | plane1 |
2917 // | |
2918 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002919
2920 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2921 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2922 zout = min(DEPTH_GEMM3D - 1, zout);
2923
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002924 // Add offset due to the cross plane paddings
2925 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002926
2927 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2928 // multiply dst_stride_z by DEPTH_GEMM3D
2929 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2930
2931 // Store 4x4 block
2932 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2933 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2934 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2935 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2936
2937#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002938 // Add offset for batched GEMM
2939 dst_addr += z * dst_stride_z;
2940
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002941 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002942 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2943 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2944 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2945 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002946#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002947}
2948
Georgios Pinitas84225582018-05-14 12:00:05 +01002949// Undefine local defines
2950#undef COLS_MTX_B
2951
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002952#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002953/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002954 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002955 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002956 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2957 *
Gian Marco19835e52018-01-30 13:35:54 +00002958 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2959 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2960 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002961 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2962 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002963 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002964 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2965 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2966 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2967 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2968 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2969 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002970 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2971 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002972 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2973 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2974 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2975 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2976 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2977 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002978 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002979 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2980 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2981 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2982 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2983 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002984 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2985 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2986 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2987 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002988 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002989 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002990 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002991 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002992 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002993 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002994 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2995 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2996 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002997 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002998 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Row of src1 (from global id 0), row of src0 (from global id 1) and batch (global id 2)
    // handled by this work-item
    const int row_b = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int row_a = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offsets inside the selected rows
    const int elem_offset_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int elem_offset_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the rows of matrix A (src0) and matrix B (src1)
    int a_offset_in_bytes = batch * src0_stride_z + row_a * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_in_bytes = row_b * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more than 3: wrap the batch index so that
    // matrix B is not slid beyond its depth
    b_offset_in_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_in_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_in_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_in_bytes);

    // End of the matrix B row, taken before the per-work-item element offset is applied
    __global half *b_end = b_ptr + COLS_B;

    a_ptr += elem_offset_a;
    b_ptr += elem_offset_b;

    // One half8 accumulator per output row of the 4x8 tile
    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    while(b_ptr <= (b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        // First step: 4 values of matrix A (interleaved) and 8 values of matrix B (transposed)
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;

        // Second step
        a_vals = vload4(0, a_ptr);
        b_vals = vload8(0, b_ptr);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Left-over loop: one accumulation step per iteration
    while(b_ptr < b_end)
    {
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 += (half8)a_vals.s0 * b_vals;
        acc1 += (half8)a_vals.s1 * b_vals;
        acc2 += (half8)a_vals.s2 * b_vals;
        acc3 += (half8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the matrix product by ALPHA
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8          vec_c     = vload8(0, vec_c_ptr);
    // clang-format on
    // *INDENT-ON*

    // The same 8 values of the vector (src2) are added to each of the four rows
    acc0 += vec_c;
    acc1 += vec_c;
    acc2 += vec_c;
    acc3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Destination address of the tile
    Image           dst      = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored in a 3D tensor: consecutive planes are separated by
    // cross_plane_pad rows, so each row of the tile needs an extra offset that depends on
    // the plane it belongs to.

    // Plane of each of the 4 rows: row index (get_global_id(1) * 4 + 0..3) divided by HEIGHT_GEMM3D,
    // clamped to the last plane
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);

    // Convert the plane index into the byte offset introduced by the cross plane paddings
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batched GEMM: the batches are in the fourth dimension, hence dst_stride_z is
    // multiplied by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Store the 4x8 block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + plane_offset.s0));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + plane_offset.s1));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + plane_offset.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: move to the output plane of this batch
    dst_addr += batch * dst_stride_z;

    // Store the 4x8 block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003154
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1), accumulating the result in 32-bit floating point variables.
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Row of src1 (from global id 0), row of src0 (from global id 1) and batch (global id 2)
    // handled by this work-item
    const int row_b = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int row_a = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offsets inside the selected rows
    const int elem_offset_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int elem_offset_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the rows of matrix A (src0) and matrix B (src1)
    int a_offset_in_bytes = batch * src0_stride_z + row_a * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_in_bytes = row_b * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more than 3: wrap the batch index so that
    // matrix B is not slid beyond its depth
    b_offset_in_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_in_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_in_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_in_bytes);

    // End of the matrix B row, taken before the per-work-item element offset is applied
    __global half *b_end = b_ptr + COLS_B;

    a_ptr += elem_offset_a;
    b_ptr += elem_offset_b;

    // One float8 accumulator per output row of the 4x8 tile; the half inputs are widened
    // to float before being multiplied and accumulated
    float8 acc0 = 0.0f;
    float8 acc1 = 0.0f;
    float8 acc2 = 0.0f;
    float8 acc3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    while(b_ptr <= (b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        // First step: 4 values of matrix A (interleaved) and 8 values of matrix B (transposed)
        float4 a_vals = convert_float4(vload4(0, a_ptr));
        float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;

        // Second step
        a_vals = convert_float4(vload4(0, a_ptr));
        b_vals = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Left-over loop: one accumulation step per iteration
    while(b_ptr < b_end)
    {
        float4 a_vals = convert_float4(vload4(0, a_ptr));
        float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the matrix product by ALPHA
    acc0 *= (float8)ALPHA;
    acc1 *= (float8)ALPHA;
    acc2 *= (float8)ALPHA;
    acc3 *= (float8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8         vec_c     = convert_float8(vload8(0, vec_c_ptr));
    // clang-format on
    // *INDENT-ON*

    // The same 8 values of the vector (src2) are added to each of the four rows
    acc0 += vec_c;
    acc1 += vec_c;
    acc2 += vec_c;
    acc3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Destination address of the tile
    Image           dst      = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored in a 3D tensor: consecutive planes are separated by
    // cross_plane_pad rows, so each row of the tile needs an extra offset that depends on
    // the plane it belongs to.

    // Plane of each of the 4 rows: row index (get_global_id(1) * 4 + 0..3) divided by HEIGHT_GEMM3D,
    // clamped to the last plane
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);

    // Convert the plane index into the byte offset introduced by the cross plane paddings
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batched GEMM: the batches are in the fourth dimension, hence dst_stride_z is
    // multiplied by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Narrow the accumulators back to half and store the 4x8 block
    vstore8(convert_half8(acc0), 0, (__global half *)(dst_addr + 0 * dst_stride_y + plane_offset.s0));
    vstore8(convert_half8(acc1), 0, (__global half *)(dst_addr + 1 * dst_stride_y + plane_offset.s1));
    vstore8(convert_half8(acc2), 0, (__global half *)(dst_addr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(convert_half8(acc3), 0, (__global half *)(dst_addr + 3 * dst_stride_y + plane_offset.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: move to the output plane of this batch
    dst_addr += batch * dst_stride_z;

    // Narrow the accumulators back to half and store the 4x8 block
    vstore8(convert_half8(acc0), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(convert_half8(acc1), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(convert_half8(acc2), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(convert_half8(acc3), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3356
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Row of src1 (from global id 0), row of src0 (from global id 1) and batch (global id 2)
    // handled by this work-item
    const int row_b = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int row_a = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offsets inside the selected rows
    const int elem_offset_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int elem_offset_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the rows of matrix A (src0) and matrix B (src1)
    int a_offset_in_bytes = batch * src0_stride_z + row_a * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_in_bytes = row_b * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more than 3: wrap the batch index so that
    // matrix B is not slid beyond its depth
    b_offset_in_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_in_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_in_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_in_bytes);

    // End of the matrix B row (not referenced by the loops below, which count iterations instead)
    __global half *b_end = b_ptr + COLS_B;

    a_ptr += elem_offset_a;
    b_ptr += elem_offset_b;

    // One half8 accumulator per output row of the 4x8 tile
    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

    // Number of accumulation steps: each step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the matrix B row
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    // Main loop: four accumulation steps per iteration. The loads use explicit offsets from
    // the current pointers, which are advanced once at the end of the iteration.
    int step = 0;
    while(step <= (int)(COLS_MTX_B - 4))
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // Each half8 load from matrix A (interleaved) provides the values for two consecutive steps:
        // lanes s0-s3 for the first one and lanes s4-s7 for the second one

        // Steps 0 and 1
        half8 a_vals = vload8(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        b_vals = vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s4, b_vals, acc0);
        acc1 = fma((half8)a_vals.s5, b_vals, acc1);
        acc2 = fma((half8)a_vals.s6, b_vals, acc2);
        acc3 = fma((half8)a_vals.s7, b_vals, acc3);

        // Steps 2 and 3
        a_vals = vload8(0, a_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 16 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        b_vals = vload8(0, b_ptr + 24 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s4, b_vals, acc0);
        acc1 = fma((half8)a_vals.s5, b_vals, acc1);
        acc2 = fma((half8)a_vals.s6, b_vals, acc2);
        acc3 = fma((half8)a_vals.s7, b_vals, acc3);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Step 0: 4 values of matrix A (interleaved) and 8 values of matrix B (transposed)
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        // Step 1
        a_vals = vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        // Step 2
        a_vals = vload4(0, a_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 16 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        // Step 3
        a_vals = vload4(0, a_ptr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 24 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1

        // Both variants consume four steps of matrix A and matrix B
        a_ptr += 16 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 32 * MULT_TRANSPOSE1XW_WIDTH;
        step += 4;
    }

    // Left-over loop: one accumulation step per iteration
    for(; step < (int)(COLS_MTX_B); ++step)
    {
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the matrix product by ALPHA
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8          vec_c     = vload8(0, vec_c_ptr);
    // clang-format on
    // *INDENT-ON*

    // The same 8 values of the vector (src2) are added to each of the four rows
    acc0 += vec_c;
    acc1 += vec_c;
    acc2 += vec_c;
    acc3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Destination address of the tile
    Image           dst      = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored in a 3D tensor: consecutive planes are separated by
    // cross_plane_pad rows, so each row of the tile needs an extra offset that depends on
    // the plane it belongs to.

    // Plane of each of the 4 rows: row index (get_global_id(1) * 4 + 0..3) divided by HEIGHT_GEMM3D,
    // clamped to the last plane
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);

    // Convert the plane index into the byte offset introduced by the cross plane paddings
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batched GEMM: the batches are in the fourth dimension, hence dst_stride_z is
    // multiplied by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Store the 4x8 block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + plane_offset.s0));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + plane_offset.s1));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + plane_offset.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: move to the output plane of this batch
    dst_addr += batch * dst_stride_z;

    // Store the 4x8 block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitas84225582018-05-14 12:00:05 +01003637
3638// Undefine local defines
3639#undef COLS_MTX_B
3640
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003641#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003642
Gian Marco36a0a462018-01-12 10:21:40 +00003643#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003644
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003645#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * Each work-item produces a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by NUM_ELEMS_PROCESSED_PER_THREAD_X columns of the destination.
 * The accumulators and the plane offsets are only provided for up to 4 rows, so NUM_ELEMS_PROCESSED_PER_THREAD_Y values above 4 are not handled by this kernel.
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. The same vector is added to every row of the output tile.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of a plane of the tensor that has to be reinterpreted as 3D.
 *       -# DEPTH_GEMM3D: The depth (number of planes) of the tensor that has to be reinterpreted as 3D
 *          NOTE(review): the kernel divides the row index M by HEIGHT_GEMM3D to find the plane, so (HEIGHT_GEMM3D * DEPTH_GEMM3D)
 *          presumably equals the number of rows of the 2D view - confirm against the host-side configuration.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings for the input tensor (only if defined REINTERPRET_INPUT_AS_3D).
 *                                                The kernel multiplies this value by src0_stride_y, i.e. it is applied as a number of rows.
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D).
 *                                                The kernel multiplies this value by dst_stride_y, i.e. it is applied as a number of rows.
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                     VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (.s0 is the byte offset into src0, .s1 is the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    // One lane per row of the tile; the plane index is clamped to the last plane
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    // (zin now holds a per-row byte offset instead of a plane index)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator vector per output row of the tile
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: each iteration consumes two elements of every A row and two rows of matrix B
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        // LOAD_BLOCK is defined outside this section; it is expected to declare a0..a3 (as used below)
        // and to apply the per-row offsets zin.s0..zin.s3 - TODO confirm against the macro definition
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: one element of every A row and one row of matrix B per iteration
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (VECTOR_TYPE)ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc1 = acc1 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc2 = acc2 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc3 = acc3 * (VECTOR_TYPE)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

#if defined(ADD_VEC_C)
    // Load the slice of the vector src2 matching the columns of this tile and add it to every row
    // *INDENT-OFF*
    // clang-format off
    __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    VECTOR_TYPE         c0        = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Batch index
    int z = get_global_id(2);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    // One lane per row of the tile; the plane index is clamped to the last plane
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    // (zout now holds a per-row byte offset instead of a plane index)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store output block
    // STORE_BLOCK is defined outside this section; it is expected to write acc0..acc3 using the
    // per-row offsets zout.s0..zout.s3 - TODO confirm against the macro definition
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store output block
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
    (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003961
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003962/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003963 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003964 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3965 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003966 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
3967 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3968 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3969 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3970 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003971 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3972 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003973 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003974 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3975 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003976 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3977 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3978 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3979 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3980 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003981 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3982 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003983 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
3984 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3985 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3986 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3987 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3988 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3989 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3990 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3991 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3992 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3993 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3994 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003995 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
3996 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3997 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3998 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003999 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4000 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4001 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4002 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4003 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
4004 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004005 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4006 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4007 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004008 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4009 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004010 */
4011__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4012 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004013#if defined(ADD_VEC_C)
4014 VECTOR_DECLARATION(src2),
4015#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004016 IMAGE_DECLARATION(dst),
4017 uint src0_stride_z,
4018 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004019 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004020#if defined(REINTERPRET_INPUT_AS_3D)
4021 ,
4022 uint src_cross_plane_pad
4023#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004024#if defined(REINTERPRET_OUTPUT_AS_3D)
4025 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004026 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004027#endif // REINTERPRET_OUTPUT_AS_3D
4028 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004029{
4030 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4031
4032 // Compute starting address for matrix A and matrix B
4033 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4034
4035 // Update address for matrix A
4036 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4037
4038 // Update address for matrix B
4039 src_addr.s1 += idx * sizeof(float);
4040
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004041#if defined(REINTERPRET_INPUT_AS_3D)
4042 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4043 // in order to take into account the presence of possible cross plane paddings
4044 //
4045 // | |
4046 // | plane0 |
4047 // | |
4048 // |__________________|
4049 // |******************|
4050 // | cross_plane_pad |
4051 // |******************|
4052 // | |
4053 // | plane1 |
4054 // | |
4055 // |__________________|
4056
4057 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4058 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4059 zin = min(DEPTH_GEMM3D - 1, zin);
4060
4061 // Add offset due to the cross plane paddings
4062 zin *= (src_cross_plane_pad * src0_stride_y);
4063
4064 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4065 // multiply src0_stride_z by DEPTH_GEMM3D
4066 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4067
4068#else // defined(REINTERPRET_INPUT_AS_3D)
4069
Gian Marcoae2af742018-02-15 12:35:44 +00004070 // Add offset for batched GEMM
4071 src_addr.s0 += get_global_id(2) * src0_stride_z;
4072
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004073#endif // defined(REINTERPRET_INPUT_AS_3D)
4074
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004075#if defined(MATRIX_B_DEPTH)
4076 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4077 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4078#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004079 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004080#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004081
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004082 // Initialize accumulators
4083 float acc00 = 0.0f;
4084 float acc01 = 0.0f;
4085 float acc02 = 0.0f;
4086 float acc03 = 0.0f;
4087
4088#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4089 float acc10 = 0.0f;
4090 float acc11 = 0.0f;
4091 float acc12 = 0.0f;
4092 float acc13 = 0.0f;
4093#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4094
4095#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4096 float acc20 = 0.0f;
4097 float acc21 = 0.0f;
4098 float acc22 = 0.0f;
4099 float acc23 = 0.0f;
4100#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4101
4102#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4103 float acc30 = 0.0f;
4104 float acc31 = 0.0f;
4105 float acc32 = 0.0f;
4106 float acc33 = 0.0f;
4107#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4108
4109 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004110 int i = 0;
4111 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004112 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004113#if defined(REINTERPRET_INPUT_AS_3D)
4114 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004115 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4116#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004117 // Load values from matrix A and matrix B
4118 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004119#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004120 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004121#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4122#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004123 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004124#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4125#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004126 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004128#endif // defined(REINTERPRET_INPUT_AS_3D)
4129
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004130 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4131 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004132
4133 // Multiply and accumulate
4134 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004135 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004136 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004137 acc03 = fma(a0.s0, b0.s3, acc03);
4138
4139#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004140
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004141 acc10 = fma(a1.s0, b0.s0, acc10);
4142 acc11 = fma(a1.s0, b0.s1, acc11);
4143 acc12 = fma(a1.s0, b0.s2, acc12);
4144 acc13 = fma(a1.s0, b0.s3, acc13);
4145
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004146#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4147#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004148
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004149 acc20 = fma(a2.s0, b0.s0, acc20);
4150 acc21 = fma(a2.s0, b0.s1, acc21);
4151 acc22 = fma(a2.s0, b0.s2, acc22);
4152 acc23 = fma(a2.s0, b0.s3, acc23);
4153
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004154#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4155#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004156
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004157 acc30 = fma(a3.s0, b0.s0, acc30);
4158 acc31 = fma(a3.s0, b0.s1, acc31);
4159 acc32 = fma(a3.s0, b0.s2, acc32);
4160 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004161#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004162
4163 // Load values from matrix A and matrix B
4164 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4165 src_addr.s1 += src1_stride_y;
4166
4167 // Multiply and accumulate
4168 acc00 = fma(a0.s1, b0.s0, acc00);
4169 acc01 = fma(a0.s1, b0.s1, acc01);
4170 acc02 = fma(a0.s1, b0.s2, acc02);
4171 acc03 = fma(a0.s1, b0.s3, acc03);
4172
4173#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4174
4175 acc10 = fma(a1.s1, b0.s0, acc10);
4176 acc11 = fma(a1.s1, b0.s1, acc11);
4177 acc12 = fma(a1.s1, b0.s2, acc12);
4178 acc13 = fma(a1.s1, b0.s3, acc13);
4179
4180#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4181#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4182
4183 acc20 = fma(a2.s1, b0.s0, acc20);
4184 acc21 = fma(a2.s1, b0.s1, acc21);
4185 acc22 = fma(a2.s1, b0.s2, acc22);
4186 acc23 = fma(a2.s1, b0.s3, acc23);
4187
4188#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4189#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4190
4191 acc30 = fma(a3.s1, b0.s0, acc30);
4192 acc31 = fma(a3.s1, b0.s1, acc31);
4193 acc32 = fma(a3.s1, b0.s2, acc32);
4194 acc33 = fma(a3.s1, b0.s3, acc33);
4195#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4196
4197 // Load values from matrix A and matrix B
4198 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4199 src_addr.s1 += src1_stride_y;
4200
4201 // Multiply and accumulate
4202 acc00 = fma(a0.s2, b0.s0, acc00);
4203 acc01 = fma(a0.s2, b0.s1, acc01);
4204 acc02 = fma(a0.s2, b0.s2, acc02);
4205 acc03 = fma(a0.s2, b0.s3, acc03);
4206
4207#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4208
4209 acc10 = fma(a1.s2, b0.s0, acc10);
4210 acc11 = fma(a1.s2, b0.s1, acc11);
4211 acc12 = fma(a1.s2, b0.s2, acc12);
4212 acc13 = fma(a1.s2, b0.s3, acc13);
4213
4214#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4215#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4216
4217 acc20 = fma(a2.s2, b0.s0, acc20);
4218 acc21 = fma(a2.s2, b0.s1, acc21);
4219 acc22 = fma(a2.s2, b0.s2, acc22);
4220 acc23 = fma(a2.s2, b0.s3, acc23);
4221
4222#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4223#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4224
4225 acc30 = fma(a3.s2, b0.s0, acc30);
4226 acc31 = fma(a3.s2, b0.s1, acc31);
4227 acc32 = fma(a3.s2, b0.s2, acc32);
4228 acc33 = fma(a3.s2, b0.s3, acc33);
4229#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4230
4231 // Load values from matrix A and matrix B
4232 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4233 src_addr.s1 += src1_stride_y;
4234
4235 // Multiply and accumulate
4236 acc00 = fma(a0.s3, b0.s0, acc00);
4237 acc01 = fma(a0.s3, b0.s1, acc01);
4238 acc02 = fma(a0.s3, b0.s2, acc02);
4239 acc03 = fma(a0.s3, b0.s3, acc03);
4240
4241#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4242
4243 acc10 = fma(a1.s3, b0.s0, acc10);
4244 acc11 = fma(a1.s3, b0.s1, acc11);
4245 acc12 = fma(a1.s3, b0.s2, acc12);
4246 acc13 = fma(a1.s3, b0.s3, acc13);
4247
4248#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4249#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4250
4251 acc20 = fma(a2.s3, b0.s0, acc20);
4252 acc21 = fma(a2.s3, b0.s1, acc21);
4253 acc22 = fma(a2.s3, b0.s2, acc22);
4254 acc23 = fma(a2.s3, b0.s3, acc23);
4255
4256#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4257#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4258
4259 acc30 = fma(a3.s3, b0.s0, acc30);
4260 acc31 = fma(a3.s3, b0.s1, acc31);
4261 acc32 = fma(a3.s3, b0.s2, acc32);
4262 acc33 = fma(a3.s3, b0.s3, acc33);
4263#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4264
4265 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004266 }
4267
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004268 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004269 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004270#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004271 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004272 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4273#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4274 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4275#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4276#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4277 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4278#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4279#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4280 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4281#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4282#else // defined(REINTERPRET_INPUT_AS_3D)
4283 // Load values from matrix A
4284 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4286 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4287#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4288#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4289 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4290#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4291#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4292 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4293#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004294#endif // defined(REINTERPRET_INPUT_AS_3D)
4295
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004296 // Load values from matrix B
4297 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004298 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004299
4300 // Multiply and accumulate
4301 acc00 = fma(a0, b0.s0, acc00);
4302 acc01 = fma(a0, b0.s1, acc01);
4303 acc02 = fma(a0, b0.s2, acc02);
4304 acc03 = fma(a0, b0.s3, acc03);
4305#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4306 acc10 = fma(a1, b0.s0, acc10);
4307 acc11 = fma(a1, b0.s1, acc11);
4308 acc12 = fma(a1, b0.s2, acc12);
4309 acc13 = fma(a1, b0.s3, acc13);
4310#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4311#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4312 acc20 = fma(a2, b0.s0, acc20);
4313 acc21 = fma(a2, b0.s1, acc21);
4314 acc22 = fma(a2, b0.s2, acc22);
4315 acc23 = fma(a2, b0.s3, acc23);
4316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4318 acc30 = fma(a3, b0.s0, acc30);
4319 acc31 = fma(a3, b0.s1, acc31);
4320 acc32 = fma(a3, b0.s2, acc32);
4321 acc33 = fma(a3, b0.s3, acc33);
4322#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004323
4324 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004325 }
4326
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004327 int z = get_global_id(2);
4328
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004329 // Compute destination address
4330 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4331
4332 // Multiply by the weight of matrix-matrix product and store the result
4333#if defined(ALPHA)
4334 acc00 = acc00 * ALPHA;
4335 acc01 = acc01 * ALPHA;
4336 acc02 = acc02 * ALPHA;
4337 acc03 = acc03 * ALPHA;
4338#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004339#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004340 acc10 = acc10 * ALPHA;
4341 acc11 = acc11 * ALPHA;
4342 acc12 = acc12 * ALPHA;
4343 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004344#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004346 acc20 = acc20 * ALPHA;
4347 acc21 = acc21 * ALPHA;
4348 acc22 = acc22 * ALPHA;
4349 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004350#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4351#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004352 acc30 = acc30 * ALPHA;
4353 acc31 = acc31 * ALPHA;
4354 acc32 = acc32 * ALPHA;
4355 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004356#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4357
4358 // Compute dst address
4359 __global uchar *dst_addr = offset(&dst, 0, 0);
4360
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004361#if defined(ADD_VEC_C)
4362 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4363 float4 c0 = vload4(0, src2_addr);
4364
4365 acc00 += c0.s0;
4366 acc01 += c0.s1;
4367 acc02 += c0.s2;
4368 acc03 += c0.s3;
4369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4370 acc10 += c0.s0;
4371 acc11 += c0.s1;
4372 acc12 += c0.s2;
4373 acc13 += c0.s3;
4374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4376 acc20 += c0.s0;
4377 acc21 += c0.s1;
4378 acc22 += c0.s2;
4379 acc23 += c0.s3;
4380#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4382 acc30 += c0.s0;
4383 acc31 += c0.s1;
4384 acc32 += c0.s2;
4385 acc33 += c0.s3;
4386#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4387#endif /* defined(ADD_VEC_C) */
4388
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004389#if defined(REINTERPRET_OUTPUT_AS_3D)
4390 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004391 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004392 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004393 // | |
4394 // | plane0 |
4395 // | |
4396 // |__________________|
4397 // |******************|
4398 // | cross_plane_pad |
4399 // |******************|
4400 // | |
4401 // | plane1 |
4402 // | |
4403 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004404
4405 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4406 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4407 zout = min(DEPTH_GEMM3D - 1, zout);
4408
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004409 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004410 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004411
4412 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4413 // multiply dst_stride_z by DEPTH_GEMM3D
4414 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4415
4416 // Store the output block
4417 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4418#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4419 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4420#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4421#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4422 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4423#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4424#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4425 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004426#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004427
4428#else // defined(REINTERPRET_OUTPUT_AS_3D)
4429 // Add offset for batched GEMM
4430 dst_addr += z * dst_stride_z;
4431
4432 // Store the output block
4433 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4434#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4435 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4436#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4437#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4438 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4439#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4440#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4441 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4442#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4443#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004444}
4445
4446/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4447 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004448 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4449 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004450 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4451 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
4452 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4453 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
4454 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4455 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004456 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
4457 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004458 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004459 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4460 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004461 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4462 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4463 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4464 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
4465 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004466 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4467 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004468 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
4469 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4470 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
4471 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4472 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
4473 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4474 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4475 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4476 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
4477 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4478 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
4479 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004480 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
4481 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
4482 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4483 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004484 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4485 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4486 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4487 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4488 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4489 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004490 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4491 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4492 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004493 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4494 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004495 */
4496__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
4497 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004498#if defined(ADD_VEC_C)
4499 VECTOR_DECLARATION(src2),
4500#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004501 IMAGE_DECLARATION(dst),
4502 uint src0_stride_z,
4503 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004504 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004505#if defined(REINTERPRET_INPUT_AS_3D)
4506 ,
4507 uint src_cross_plane_pad
4508#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004509#if defined(REINTERPRET_OUTPUT_AS_3D)
4510 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004511 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004512#endif // REINTERPRET_OUTPUT_AS_3D
4513 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004514{
4515 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4516 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4517
4518 // Compute starting address for matrix A and Matrix B
4519 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4520
4521 // Update address for the matrix A
4522 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4523
4524 // Update address for the matrix B
4525 src_addr.s1 += idx * sizeof(float);
4526
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004527#if defined(REINTERPRET_INPUT_AS_3D)
4528 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4529 // in order to take into account the presence of possible cross plane paddings
4530 //
4531 // | |
4532 // | plane0 |
4533 // | |
4534 // |__________________|
4535 // |******************|
4536 // | cross_plane_pad |
4537 // |******************|
4538 // | |
4539 // | plane1 |
4540 // | |
4541 // |__________________|
4542
4543 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4544 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4545 zin = min(DEPTH_GEMM3D - 1, zin);
4546
4547 // Add offset due to the cross plane paddings
4548 zin *= (src_cross_plane_pad * src0_stride_y);
4549
4550 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4551 // multiply src0_stride_z by DEPTH_GEMM3D
4552 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4553
4554#else // defined(REINTERPRET_INPUT_AS_3D)
4555
Gian Marcoae2af742018-02-15 12:35:44 +00004556 // Add offset for batched GEMM
4557 src_addr.s0 += get_global_id(2) * src0_stride_z;
4558
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004559#endif // defined(REINTERPRET_INPUT_AS_3D)
4560
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004561#if defined(MATRIX_B_DEPTH)
4562 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4563 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4564#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004565 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004566#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004567
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004568 // Initialize accumulators
4569 float acc00 = 0.0f;
4570 float acc01 = 0.0f;
4571
4572#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4573 float acc10 = 0.0f;
4574 float acc11 = 0.0f;
4575#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4576#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4577 float acc20 = 0.0f;
4578 float acc21 = 0.0f;
4579#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4580#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4581 float acc30 = 0.0f;
4582 float acc31 = 0.0f;
4583#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4584
4585 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004586 int i = 0;
4587 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004588 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004589#if defined(REINTERPRET_INPUT_AS_3D)
4590 // Load values from matrix A
4591 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
4592#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004593 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004594 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004595#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004596
4597 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004598 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4599 src_addr.s1 += src1_stride_y;
4600 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4601 src_addr.s1 += src1_stride_y;
4602 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4603 src_addr.s1 += src1_stride_y;
4604 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4605 src_addr.s1 += src1_stride_y;
4606 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4607 src_addr.s1 += src1_stride_y;
4608 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4609 src_addr.s1 += src1_stride_y;
4610 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4611 src_addr.s1 += src1_stride_y;
4612 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4613 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004614
4615 // Multiply and accumulate
4616 acc00 = fma(a0.s0, b0.s0, acc00);
4617 acc00 = fma(a0.s1, b1.s0, acc00);
4618 acc00 = fma(a0.s2, b2.s0, acc00);
4619 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004620 acc00 = fma(a0.s4, b4.s0, acc00);
4621 acc00 = fma(a0.s5, b5.s0, acc00);
4622 acc00 = fma(a0.s6, b6.s0, acc00);
4623 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004624
4625 acc01 = fma(a0.s0, b0.s1, acc01);
4626 acc01 = fma(a0.s1, b1.s1, acc01);
4627 acc01 = fma(a0.s2, b2.s1, acc01);
4628 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004629 acc01 = fma(a0.s4, b4.s1, acc01);
4630 acc01 = fma(a0.s5, b5.s1, acc01);
4631 acc01 = fma(a0.s6, b6.s1, acc01);
4632 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004633
4634#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004635#if defined(REINTERPRET_INPUT_AS_3D)
4636 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4637#else // defined(REINTERPRET_INPUT_AS_3D)
4638 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4639#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004640 acc10 = fma(a0.s0, b0.s0, acc10);
4641 acc10 = fma(a0.s1, b1.s0, acc10);
4642 acc10 = fma(a0.s2, b2.s0, acc10);
4643 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004644 acc10 = fma(a0.s4, b4.s0, acc10);
4645 acc10 = fma(a0.s5, b5.s0, acc10);
4646 acc10 = fma(a0.s6, b6.s0, acc10);
4647 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004648
4649 acc11 = fma(a0.s0, b0.s1, acc11);
4650 acc11 = fma(a0.s1, b1.s1, acc11);
4651 acc11 = fma(a0.s2, b2.s1, acc11);
4652 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004653 acc11 = fma(a0.s4, b4.s1, acc11);
4654 acc11 = fma(a0.s5, b5.s1, acc11);
4655 acc11 = fma(a0.s6, b6.s1, acc11);
4656 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004657#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4658#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004659#if defined(REINTERPRET_INPUT_AS_3D)
4660 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4661#else // defined(REINTERPRET_INPUT_AS_3D)
4662 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4663#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004664 acc20 = fma(a0.s0, b0.s0, acc20);
4665 acc20 = fma(a0.s1, b1.s0, acc20);
4666 acc20 = fma(a0.s2, b2.s0, acc20);
4667 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004668 acc20 = fma(a0.s4, b4.s0, acc20);
4669 acc20 = fma(a0.s5, b5.s0, acc20);
4670 acc20 = fma(a0.s6, b6.s0, acc20);
4671 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004672
4673 acc21 = fma(a0.s0, b0.s1, acc21);
4674 acc21 = fma(a0.s1, b1.s1, acc21);
4675 acc21 = fma(a0.s2, b2.s1, acc21);
4676 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004677 acc21 = fma(a0.s4, b4.s1, acc21);
4678 acc21 = fma(a0.s5, b5.s1, acc21);
4679 acc21 = fma(a0.s6, b6.s1, acc21);
4680 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004683#if defined(REINTERPRET_INPUT_AS_3D)
4684 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4685#else // defined(REINTERPRET_INPUT_AS_3D)
4686 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4687#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004688 acc30 = fma(a0.s0, b0.s0, acc30);
4689 acc30 = fma(a0.s1, b1.s0, acc30);
4690 acc30 = fma(a0.s2, b2.s0, acc30);
4691 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004692 acc30 = fma(a0.s4, b4.s0, acc30);
4693 acc30 = fma(a0.s5, b5.s0, acc30);
4694 acc30 = fma(a0.s6, b6.s0, acc30);
4695 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004696
4697 acc31 = fma(a0.s0, b0.s1, acc31);
4698 acc31 = fma(a0.s1, b1.s1, acc31);
4699 acc31 = fma(a0.s2, b2.s1, acc31);
4700 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004701 acc31 = fma(a0.s4, b4.s1, acc31);
4702 acc31 = fma(a0.s5, b5.s1, acc31);
4703 acc31 = fma(a0.s6, b6.s1, acc31);
4704 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004705#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004706
4707 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004708 }
4709 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004710 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004711 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004712#if defined(REINTERPRET_INPUT_AS_3D)
4713 // Load values from matrix A
4714 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4715#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4716 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4717#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4718#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4719 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4720#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4721#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4722 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4723#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4724#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004725 // Load values from matrix A
4726 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4727#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4728 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4729#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4730#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4731 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4732#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4733#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4734 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4735#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004736#endif // defined(REINTERPRET_INPUT_AS_3D)
4737
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004738 // Load values from matrix B
4739 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004740 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004741
4742 // Multiply and accumulate
4743 acc00 = fma(a0, b0.s0, acc00);
4744 acc01 = fma(a0, b0.s1, acc01);
4745#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4746 acc10 = fma(a1, b0.s0, acc10);
4747 acc11 = fma(a1, b0.s1, acc11);
4748#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4749#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4750 acc20 = fma(a2, b0.s0, acc20);
4751 acc21 = fma(a2, b0.s1, acc21);
4752#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4753#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4754 acc30 = fma(a3, b0.s0, acc30);
4755 acc31 = fma(a3, b0.s1, acc31);
4756#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004757
4758 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004759 }
4760
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004761 // Multiply by the weight of matrix-matrix product and store the result
4762#if defined(ALPHA)
4763 acc00 = acc00 * ALPHA;
4764 acc01 = acc01 * ALPHA;
4765#endif // defined(ALPHA)
4766#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4767 acc10 = acc10 * ALPHA;
4768 acc11 = acc11 * ALPHA;
4769#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4770#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4771 acc20 = acc20 * ALPHA;
4772 acc21 = acc21 * ALPHA;
4773#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4774#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4775 acc30 = acc30 * ALPHA;
4776 acc31 = acc31 * ALPHA;
4777#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4778
4779 int z = get_global_id(2);
4780
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004781 // Compute destination address
4782 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4783
Gian Marcoae2af742018-02-15 12:35:44 +00004784 // Compute dst address
4785 __global uchar *dst_addr = offset(&dst, 0, 0);
4786
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004787#if defined(ADD_VEC_C)
4788 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4789 float2 c0 = vload2(0, src2_addr);
4790
4791 acc00 += c0.s0;
4792 acc01 += c0.s1;
4793#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4794 acc10 += c0.s0;
4795 acc11 += c0.s1;
4796#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4797#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4798 acc20 += c0.s0;
4799 acc21 += c0.s1;
4800#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4801#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4802 acc30 += c0.s0;
4803 acc31 += c0.s1;
4804#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4805#endif /* defined(ADD_VEC_C) */
4806
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004807#if defined(REINTERPRET_OUTPUT_AS_3D)
4808 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004809 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004810 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004811 // | |
4812 // | plane0 |
4813 // | |
4814 // |__________________|
4815 // |******************|
4816 // | cross_plane_pad |
4817 // |******************|
4818 // | |
4819 // | plane1 |
4820 // | |
4821 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00004822
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004823 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4824 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4825 zout = min(DEPTH_GEMM3D - 1, zout);
4826
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004827 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004828 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004829
4830 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4831 // multiply dst_stride_z by DEPTH_GEMM3D
4832 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4833
4834 // Store the output block
4835 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004836#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004837 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004838#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4839#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004840 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004841#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4842#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004843 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004844#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004845
4846#else // defined(REINTERPRET_OUTPUT_AS_3D)
4847 // Add offset for batched GEMM
4848 dst_addr += z * dst_stride_z;
4849
4850 // Store the output block
4851 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4852#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4853 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4854#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4855#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4856 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4857#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4858#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4859 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4860#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4861#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004862}
4863
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004864#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
 *       The scaling by alpha and the addition of src2 are performed on the 32-bit accumulators as well: the result is converted to half only once,
 *       right before it is stored.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 *       NOTE(review): each work-item reads 8 elements of matrix B and writes 8 elements of dst per row (vload8/vstore8), so 8 looks like the matching value - confirm against the host-side configuration.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *          NOTE(review): the kernel divides the row index (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D, so this product presumably matches the rows of matrix A - confirm.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // 32-bit accumulators: one row of 8 output elements per processed row of matrix A
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of matrix A (and 4 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: consume the remaining (COLS_A % 4) columns of matrix A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        // Advance matrix A by one element and matrix B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product.
    // The scaling is applied to the 32-bit accumulators: converting to half first would turn any accumulator
    // above the half range (65504) into infinity even when alpha * acc is representable, and would round twice.
    // ALPHA is parenthesised so that a negative or compound -DALPHA value is cast as a single operand.
#if defined(ALPHA)
    acc0 = acc0 * (float8)(ALPHA);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (float8)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (float8)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (float8)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the vector src2 to every processed row, still in 32-bit precision
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8         c0        = convert_float8(vload8(0, src2_addr));
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Single conversion of the final result to the output data type
    half8 hacc0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 hacc1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 hacc2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 hacc3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
5223
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5225 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005226 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5227 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005228 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5229 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5230 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5231 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5232 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5233 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5234 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5235 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005236 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5237 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005238 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5239 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5240 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5241 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5242 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005243 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5244 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005245 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5246 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5247 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5248 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5249 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5250 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5251 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5252 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5253 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5254 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5255 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5256 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005257 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
5258 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5259 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5260 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005261 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5262 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5263 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5264 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5265 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5266 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005267 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5268 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5269 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005270 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5271 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005272 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
{
    // Each work-item produces a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by 8 half values (one half8 per row).
    // get_global_id(0) selects the columns, get_global_id(1) the rows and get_global_id(2) the batch.
    const int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Byte offsets of the current read position: .s0 walks matrix A, .s1 walks matrix B
    int2 offs = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // ---- Matrix A: first row of the tile, then the batch ----
    offs.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

#if defined(REINTERPRET_INPUT_AS_3D)
    // The 2D tile of A is read out of a 3D tensor: consecutive planes may be separated by
    // src_cross_plane_pad padding rows, so every row of the tile needs an extra byte offset
    // that depends on the plane it falls in.
    //
    // Plane index of each of the (up to 4) rows: row index M divided by HEIGHT_GEMM3D,
    // clamped to the last plane (DEPTH_GEMM3D - 1).
    uint4 a_plane_off = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    a_plane_off       = min(DEPTH_GEMM3D - 1, a_plane_off);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    a_plane_off *= (src_cross_plane_pad * src0_stride_y);

    // Batched GEMM: batches live in the fourth dimension, hence src0_stride_z is scaled by DEPTH_GEMM3D
    offs.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Batched GEMM: one Z slice of A per batch
    offs.s0 += get_global_id(2) * src0_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // ---- Matrix B: first column of the tile, then the batch ----
    offs.s1 += first_col * sizeof(half);

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more: wrap the batch index instead of sliding B
    offs.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    offs.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One accumulator per output row of the tile
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Index along the accumulation dimension (columns of A / rows of B)
    int k = 0;

    // Main loop: consume 4 columns of A and 4 rows of B per iteration
    while(k <= ((int)COLS_A - 4))
    {
        // Fetch 4 consecutive values from every row of A handled by this work-item
#if defined(REINTERPRET_INPUT_AS_3D)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, offs.s0, src0_stride_y, a_plane_off.s);
#else  // defined(REINTERPRET_INPUT_AS_3D)
        half4 a0 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Row k + 0 of B, multiplied by component 0 of each A vector
        half8 b0 = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row k + 1 of B, multiplied by component 1 of each A vector
        b0 = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row k + 2 of B, multiplied by component 2 of each A vector
        b0 = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Row k + 3 of B, multiplied by component 3 of each A vector
        b0 = vload8(0, (__global half *)(src1_ptr + offs.s1));
        offs.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 4 half values to the right in A
        offs.s0 += 4 * sizeof(half);
        k += 4;
    }

    // Left-over loop: handle the remaining (COLS_A % 4) columns one at a time
    while(k < (int)COLS_A)
    {
        // Fetch a single value from every row of A handled by this work-item
#if defined(REINTERPRET_INPUT_AS_3D)
        half a0 = *((__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y + a_plane_off.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y + a_plane_off.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y + a_plane_off.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y + a_plane_off.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        half a0 = *((__global half *)(src0_ptr + offs.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + offs.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + offs.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + offs.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Fetch the matching row of B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + offs.s1));

        // acc += b0 * a (the scalar from A is broadcast over the 8 lanes)
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Next column of A, next row of B
        offs.s0 += sizeof(half);
        offs.s1 += src1_stride_y;
        ++k;
    }

    // Scale the matrix-matrix product by ALPHA (only when ALPHA is provided at compile time)
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 8 values of vector C (src2) to every row of the tile
    // *INDENT-OFF*
    // clang-format off
    __global half *bias_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8          bias     = vload8(0, bias_ptr);
    // clang-format on
    // *INDENT-ON*

    acc0 += bias;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += bias;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += bias;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += bias;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Batch index of this work-item
    int batch = get_global_id(2);

    // Address of the first element of the output tile
    Image           dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out_ptr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is written into a 3D tensor: consecutive planes may be separated by
    // dst_cross_plane_pad padding rows, so every row of the tile needs an extra byte offset
    // that depends on the plane it falls in.
    //
    // Plane index of each of the (up to 4) rows: row index M divided by HEIGHT_GEMM3D,
    // clamped to the last plane (DEPTH_GEMM3D - 1).
    uint4 dst_plane_off = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    dst_plane_off       = min(DEPTH_GEMM3D - 1, dst_plane_off);

    // Turn the plane index into the byte offset introduced by the cross plane paddings
    dst_plane_off *= (dst_cross_plane_pad * dst_stride_y);

    // Batched GEMM: batches live in the fourth dimension, hence dst_stride_z is scaled by DEPTH_GEMM3D
    out_ptr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Write the tile, one half8 per row
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, out_ptr, dst_stride_y, dst_plane_off.s);
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one Z slice of the output per batch
    out_ptr += batch * dst_stride_z;

    // Write the tile, one half8 per row
    vstore8(acc0, 0, (__global half *)(out_ptr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(out_ptr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(out_ptr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(out_ptr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005567#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005568
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005569#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005570
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005571#if defined(BETA)
/** In-place matrix addition for F32 data: dst = dst + BETA * src.
 *
 * The destination is read, combined with the source weighted by the scalar BETA and written back to the same
 * location. Per the original description of this kernel, dst is expected to already hold the (alpha-weighted)
 * A x B product and src is matrix C.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note Each work-item loads and stores 4 consecutive float values
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the position of this work-item inside matrix C (src) and inside the A x B result (dst)
    Tensor3D mat_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mat_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *ab_ptr = (__global float *)mat_ab.ptr;
    __global float *c_ptr  = (__global float *)mat_c.ptr;

    // Current content of the destination (A x B) and the matching values of matrix C
    float4 ab_values = vload4(0, ab_ptr);
    float4 c_values  = vload4(0, c_ptr);

    // A x B + beta * C
    float4 result = ab_values + (float4)BETA * c_values;

    // Overwrite the destination with the combined result
    vstore4(result, 0, ab_ptr);
}
5612
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005613#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** In-place matrix addition for F16 data: dst = dst + BETA * src.
 *
 * The destination is read, combined with the source weighted by the scalar BETA and written back to the same
 * location. Per the original description of this kernel, dst is expected to already hold the (alpha-weighted)
 * A x B product and src is matrix C.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note Each work-item loads and stores 8 consecutive half values
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the position of this work-item inside matrix C (src) and inside the A x B result (dst)
    Tensor3D mat_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mat_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *ab_ptr = (__global half *)mat_ab.ptr;
    __global half *c_ptr  = (__global half *)mat_c.ptr;

    // Current content of the destination (A x B) and the matching values of matrix C
    half8 ab_values = vload8(0, ab_ptr);
    half8 c_values  = vload8(0, c_ptr);

    // A x B + beta * C
    half8 result = ab_values + (half8)BETA * c_values;

    // Overwrite the destination with the combined result
    vstore8(result, 0, ab_ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005654#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005655#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005656
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005657#if defined(WIDTH_VECTOR_A)
/** Vector by matrix multiplication between each row of A (src0) and a plane of matrix B (src1), used for the locally connected layer.
 *
 * The work-item with Y index idy multiplies row idy of A by the Z plane idy of B and produces 4 consecutive
 * float values of row idy of the destination.
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First of the 4 output columns and the row handled by this work-item
    int col = get_global_id(0) * 4;
    int row = get_global_id(1);

    // Byte offsets of the current read position: .s0 walks row `row` of A, .s1 walks plane `row` of B
    int2 pos = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * row, src1_offset_first_element_in_bytes + src1_stride_z * row));
    pos.s1 += col * sizeof(float);

    // Byte offset one past the last element of the row of A
    int a_row_end = pos.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 sum = 0.0f;

    // Main loop: consume 2 elements of A and 2 rows of B per iteration
    while(pos.s0 <= (a_row_end - 2 * (int)sizeof(float)))
    {
        float2 a_pair = vload2(0, (__global float *)(src0_ptr + pos.s0));
        float4 b_row0 = vload4(0, (__global float *)(src1_ptr + pos.s1));
        float4 b_row1 = vload4(0, (__global float *)(src1_ptr + pos.s1 + src1_stride_y));

        sum += b_row0 * (float4)a_pair.s0;
        sum += b_row1 * (float4)a_pair.s1;

        pos += (int2)(2 * sizeof(float), 2 * src1_stride_y);
    }

    // Left-over loop: at most one element of A remains when WIDTH_VECTOR_A is odd
    while(pos.s0 < a_row_end)
    {
        float  a_val = *((__global float *)(src0_ptr + pos.s0));
        float4 b_row = vload4(0, (__global float *)(src1_ptr + pos.s1));

        sum += b_row * (float4)a_val;

        pos += (int2)(sizeof(float), src1_stride_y);
    }

    // Resolve the destination address of this work-item and write the 4 accumulated values
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(sum, 0, (__global float *)(offset(&dst_img, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005723#endif // defined(WIDTH_VECTOR_A)
5724
/** Adds the biases vector to a row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the biases
 * vector, sums them and writes the result back to the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve the position of this work-item inside the accumulate tensor and the biases vector
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_addr  = (__global DATA_TYPE *)accum_img.ptr;
    __global DATA_TYPE *biases_addr = (__global DATA_TYPE *)biases_vec.ptr;

    // Load VECTOR_SIZE elements from each operand
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    acc_vals = VLOAD(VECTOR_SIZE)(0, accum_addr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    bias_vals = VLOAD(VECTOR_SIZE)(0, biases_addr);

    // Element-wise sum
    acc_vals = bias_vals + acc_vals;

    // Write the sum back over the accumulate tensor
    VSTORE(VECTOR_SIZE)
    (acc_vals, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)