blob: 2ac2eb7c323f33f54f394f94ab33a932f4ad8017 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// INCn: per-lane offsets 0..n-1. Adding INC(K0) to (x * K0) yields the absolute source column of every lane
// of a K0-wide row loaded by work-item x. Only the K0 values documented as supported (2,3,4,8,16) are provided.
#define INC2 ((VEC_DATA_TYPE(uint, 2))(0, 1))
#define INC3 ((VEC_DATA_TYPE(uint, 3))(0, 1, 2))
#define INC4 ((VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3))
#define INC8 ((VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7))
#define INC16 ((VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
// Two-level expansion so that INC(K0) pastes the *value* of K0 (e.g. INC4), not the token "K0".
#define CONCAT_INC(K0) INC##K0
#define INC(K0) CONCAT_INC(K0)

// BOUNDARY_CONDITION_X(x, a)
// Zeroes the lanes of the K0-wide row vector a whose source column (x * K0 + lane) is >= SRC_WIDTH,
// keeping the other lanes unchanged: select(0, a, mask) returns a where the mask lane is set and 0 elsewhere.
// It is only needed when SRC_WIDTH is not a multiple of K0; otherwise it expands to an empty statement expression.
// NOTE(review): the mask is converted to VEC_DATA_TYPE(DATA_TYPE, K0); OpenCL's select() expects an integer mask
// type, so floating-point DATA_TYPE relies on the compiler accepting this conversion — confirm on target drivers.
#if(SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a)                                                                                                                            \
    ({                                                                                                                                                        \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item handles one M0xK0 block: get_global_id(0) selects the block column (K0 elements wide),
 * get_global_id(1) the block row (M0 rows high) and get_global_id(2) the batch.
 * V0 consecutive block rows (same y / V0) are written to the same output row, at slot (y % V0).
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16).
 *       When SRC_WIDTH is not a multiple of K0, lanes that fall past SRC_WIDTH are zeroed before being stored.
 *       The loads themselves are not bounds-checked (presumably the source tensor is padded — confirm with the host code).
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the V0 blocks sharing an output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between consecutive rows of one block in the output.
    // Fully parenthesized so the macro stays a single operand wherever it is expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    // Load values from the LHS matrix into a0..a(M0-1), then zero the lanes past SRC_WIDTH
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7
    // ---------------------------Store output values ------------------------------
    // STORE_BLOCK(M0, ...) only reads zout0..zout(M0-1): the output needs no cross-plane offsets.
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
// TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
// Gathers lane i of the row vectors a0..a(M0-1) -- which must already be declared in the calling kernel --
// into one M0-wide vector (column i of the loaded M0xK0 block) and stores it at
// output_ptr + i * output_step_x * sizeof(DATA_TYPE) bytes.
// i must be a single hexadecimal digit (0-9, A-F): it is pasted into the lane selector (.s##i) and into a
// hex integer literal (0x##i) that gives the column's numeric index.
// M0 == 5, 6, 7 are written as a 4-wide store followed by a 1-, 2- or 3-element store at element offset 4.
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                        \
    ({                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                    \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                         \
        VSTORE(M0)                                                                                      \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));   \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                        \
    ({                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                    \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                                \
        VSTORE(M0)                                                                                      \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));   \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                        \
    ({                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                    \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                       \
        VSTORE(M0)                                                                                      \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));   \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                        \
    ({                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                     \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);             \
        DATA_TYPE res1 = a4.s##i;                                                                       \
        VSTORE(4)                                                                                       \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));  \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                       \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                         \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                       \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                           \
        VSTORE(4)                                                                                         \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        VSTORE(2)                                                                                         \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                       \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                         \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                       \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                  \
        VSTORE(4)                                                                                         \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        VSTORE(3)                                                                                         \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                   \
    ({                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                               \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                 \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));              \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
245
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item handles one M0xK0 block: get_global_id(0) selects the block column (K0 elements wide),
 * get_global_id(1) the block row (M0 rows high) and get_global_id(2) the batch.
 * The block is written column by column (K0 vectors of M0 elements each).
 * V0 consecutive block rows (same y / V0) are written to the same output row, at slot (y % V0).
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16).
 *       When SRC_WIDTH is not a multiple of K0, lanes that fall past SRC_WIDTH are zeroed before being stored.
 *       The loads themselves are not bounds-checked (presumably the source tensor is padded — confirm with the host code).
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the V0 blocks sharing an output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between consecutive transposed columns of one block in the output.
    // Fully parenthesized so the macro stays a single operand wherever it is expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix into a0..a(M0-1), then zero the lanes past SRC_WIDTH
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7
    // ---------------------------Transpose and store block -----------------------
    // One store per column of the block; column indices above 9 use hex digits A-F.

    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item handles one K0xN0 block: get_global_id(0) selects the block column (N0 elements wide),
 * get_global_id(1) the block row (K0 rows high) and get_global_id(2) the batch.
 * H0 consecutive block columns (same x / H0) are written to the same output row, at slot (x % H0).
 * Rows past SRC_HEIGHT are not loaded and are stored as zeros; the columns are not bounds-checked
 * (presumably the source tensor is padded — confirm with the host code).
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 horizontal blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between the H0 blocks sharing an output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between consecutive rows of one block in the output.
    // Fully parenthesized so the macro stays a single operand wherever it is expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE))
                                 + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;
    // Rows skipped by the SRC_HEIGHT guards below therefore keep their zero value.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // STORE_BLOCK(K0, ...) only reads zout0..zout(K0-1): the output needs no cross-plane offsets.
    REPEAT_VAR_INIT_TO_CONST(K0, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
553
554#if defined(TRANSPOSE)
555/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
556 * the output matrix unrolling the values.
557 *
558 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
559 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
560 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
561 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
562 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
563 * @note The option -DTRANSPOSE must passed at compile time.
564 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000565 * N0: 2,3,4,8,16
566 * K0: 2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000567 * H0: greater than 0
568 *
569 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
570 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
571 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
573 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
575 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
576 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
577 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
578 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
579 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
580 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
581 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
582 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
583 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
584 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
585 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between consecutive blocks on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between consecutive rows of the same transposed block.
    // Fully parenthesised so the macro expands safely inside larger expressions.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: each work-item reads a K0xN0 block starting at
    // column x * N0 and row y * K0 of plane z
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: (x % H0) selects the block slot within an output row,
    // (x / H0) selects the output row, and y selects the group of H0 blocks along that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix.
    // Row 0 is loaded unconditionally (presumably the NDRange guarantees y * K0 < SRC_HEIGHT - TODO confirm);
    // every other row is loaded only if it lies inside the source height, otherwise a<i> stays zero.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    // res<n> gathers lane n of every loaded row a0..a(K0-1), turning the K0xN0 block into N0xK0.
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

#if K0 == 2
    // This part computes the following transpositions:
    // 2x2 -> 2x2
    // 2x4 -> 4x2
    // 2x8 -> 8x2
    // 2x16 -> 16x2
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
#endif // N0 > 8

#elif K0 == 3 // K0 == 3
    // This part computes the following transpositions:
    // 3x2 -> 2x3
    // 3x4 -> 4x3
    // 3x8 -> 8x3
    // 3x16 -> 16x3
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8

#elif K0 == 4 // K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not in {2, 3, 4, 8, 16}
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // The destination is not reinterpreted as 3D here, so every row uses a zero plane offset.
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
881#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
#define CONCAT(a, b) a##b

// ARM_DOT<k>(a, b, c): accumulate the k-element dot product of a and b into the scalar c,
// one fused multiply-add per element, in ascending element order.
// Every wider variant is built from narrower ones, bottoming out in ARM_DOT1.
#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)            \
    ({                               \
        ARM_DOT1((a).s0, (b).s0, c); \
        ARM_DOT1((a).s1, (b).s1, c); \
    })
#define ARM_DOT3(a, b, c)            \
    ({                               \
        ARM_DOT2(a, b, c);           \
        ARM_DOT1((a).s2, (b).s2, c); \
    })
#define ARM_DOT4(a, b, c)            \
    ({                               \
        ARM_DOT3(a, b, c);           \
        ARM_DOT1((a).s3, (b).s3, c); \
    })
// 8- and 16-wide variants split the operands into their low and high halves.
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a).lo, (b).lo, c); \
        ARM_DOT4((a).hi, (b).hi, c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a).lo, (b).lo, c); \
        ARM_DOT8((a).hi, (b).hi, c); \
    })
917
// Helper for ARM_DOT_K0XN0: accumulates the k0-element dot product of the vector a and the
// vector b<n> into lane n of the accumulator c, i.e. ARM_DOT<k0>((a), (b<n>), (c.s<n>)).
// n must be a single hex digit (0-9, A-F) so that both b##n and s##n paste into valid tokens.
#define ARM_DOT_K0XN0_LANE(k0, a, b, c, n) \
    CONCAT(ARM_DOT, k0)                    \
    ((a), (b##n), (c.s##n))

// ARM_DOT_K0XN0(k0, a, b, c): for every lane n in [0, N0), c.s<n> += dot(a, b<n>) over k0 elements,
// where b0 .. b(N0-1) are N0 separate vectors named with the prefix b.
#if N0 == 2
#define ARM_DOT_K0XN0(k0, a, b, c)          \
    ({                                      \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 0); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 1); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(k0, a, b, c)          \
    ({                                      \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 0); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 1); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 2); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(k0, a, b, c)          \
    ({                                      \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 0); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 1); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 2); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 3); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(k0, a, b, c)          \
    ({                                      \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 0); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 1); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 2); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 3); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 4); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 5); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 6); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 7); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(k0, a, b, c)          \
    ({                                      \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 0); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 1); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 2); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 3); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 4); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 5); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 6); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 7); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 8); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, 9); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, A); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, B); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, C); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, D); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, E); \
        ARM_DOT_K0XN0_LANE(k0, a, b, c, F); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1007
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 *       In that case, work-items whose first output column (x * N0) is >= N or whose first output row (y * M0) is >= M return immediately.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
 *       K is also the number of columns of the LHS matrix.
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - H0 >= 1
 * @note If -DALPHA is passed, the result of the matrix multiplication is scaled by ALPHA.
 * @note If -DBETA is passed, the bias matrix is scaled by BETA and added to the result
 *       (the scaling is skipped when -DUNIT_BETA is also passed). With -DBROADCAST_BIAS the bias
 *       is a single row which is added to each of the M0 output rows.
 * @note If -DMATRIX_B_DEPTH is passed, the RHS matrix is indexed with z % MATRIX_B_DEPTH in the Z dimension,
 *       so it is not slid past its own depth when the LHS matrix has more batches.
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix
 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                      Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                          (Optional) Pointer to the bias matrix (only if defined BETA). Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                     (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                       (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                     (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                       (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                      Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                     (Optional) Stride of the bias matrix in Z dimension (in bytes) (only if defined BETA)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                          IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                          IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                          IMAGE_DECLARATION(dst),
                                          uint lhs_stride_z,
                                          uint rhs_stride_z,
#if defined(BETA)
                                          uint bias_stride_z,
#endif //defined(BETA)
                                          uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                          ,
                                          uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                          ,
                                          uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                         )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
    // RHS_OFFSET_X : distance (in elements) between the blocks of consecutive x on the same reshaped row
    // RHS_STEP_X   : distance (in elements) between consecutive K0-wide rows of one transposed block
    // RHS_STEP_LOOP: extra multiplier applied when advancing to the next K0 slice
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address: (x % H0) selects the block slot within a reshaped row,
    // (x / H0) selects the reshaped row
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    // Zero per-row offsets used for buffers that are never reinterpreted as 3D (RHS and bias)
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    // Main loop: consume K in slices of K0 elements
    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    // Left-over accumulations: remaining K % K0 elements, one at a time
    for(; i < K; ++i)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(1, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(1, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(1, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(1, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(1, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(1, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(1, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(1, a7, b, c7);
#endif // M0 > 7

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Undefine every kernel-local macro so none of them leaks into the rest of the file
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001289
1290#define VFMA(a, b, c) \
1291 ({ \
1292 c = fma(a, b, c); \
1293 })
1294
1295#if M0 == 1
1296#define LD_RHS_VFMA_M0xN0(i, a, c) \
1297 ({ \
1298 VEC_DATA_TYPE(DATA_TYPE, N0) \
1299 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1300 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1301 })
1302#elif M0 == 2 // M0 == 2
1303#define LD_RHS_VFMA_M0xN0(i, a, c) \
1304 ({ \
1305 VEC_DATA_TYPE(DATA_TYPE, N0) \
1306 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1307 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1308 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1309 })
1310#elif M0 == 3 // M0 == 3
1311#define LD_RHS_VFMA_M0xN0(i, a, c) \
1312 ({ \
1313 VEC_DATA_TYPE(DATA_TYPE, N0) \
1314 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1315 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1316 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1317 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1318 })
1319#elif M0 == 4 // M0 == 4
1320#define LD_RHS_VFMA_M0xN0(i, a, c) \
1321 ({ \
1322 VEC_DATA_TYPE(DATA_TYPE, N0) \
1323 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1324 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1325 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1326 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1327 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1328 })
1329#elif M0 == 5 // M0 == 5
1330#define LD_RHS_VFMA_M0xN0(i, a, c) \
1331 ({ \
1332 VEC_DATA_TYPE(DATA_TYPE, N0) \
1333 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1334 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1335 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1336 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1337 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1338 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1339 })
1340#elif M0 == 6 // M0 == 6
1341#define LD_RHS_VFMA_M0xN0(i, a, c) \
1342 ({ \
1343 VEC_DATA_TYPE(DATA_TYPE, N0) \
1344 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1345 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1346 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1347 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1348 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1349 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1350 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1351 })
1352#elif M0 == 7 // M0 == 7
1353#define LD_RHS_VFMA_M0xN0(i, a, c) \
1354 ({ \
1355 VEC_DATA_TYPE(DATA_TYPE, N0) \
1356 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1357 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1358 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1359 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1360 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1361 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1362 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1363 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1364 })
1365#elif M0 == 8 // M0 == 8
1366#define LD_RHS_VFMA_M0xN0(i, a, c) \
1367 ({ \
1368 VEC_DATA_TYPE(DATA_TYPE, N0) \
1369 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1370 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1371 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1372 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1373 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1374 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1375 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1376 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1377 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1378 })
1379#else // M0 not supported
1380#error "M0 not supported"
1381#endif // M0 not supported
1382
1383/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1384 * The LHS matrix is NOT reshaped
1385 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1386 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001387 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1388 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90).
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001389 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1390 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1391 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1392 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1393 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1394 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1395 * - N0 = 2, 3, 4, 8, 16
1396 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001397 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001398 *
1399 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1400 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1401 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1402 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1403 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1404 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1405 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001406 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1407 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1408 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1409 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1410 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1411 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1412 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1413 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1414 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1415 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1416 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1417 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1418 * @param[in] bias_ptr (Optional) Pointer to the bias reshaped matrix. Supported data type: same as @p lhs_ptr
1419 * @param[in] bias_stride_x (Optional) Stride of the bias reshaped matrix in X dimension (in bytes)
1420 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1421 * @param[in] bias_stride_y (Optional) Stride of the bias reshaped matrix in Y dimension (in bytes)
1422 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1423 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1424 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1425 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1426 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1427 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1428 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1429 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1430 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1431 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1432 * @param[in] bias_stride_z (Optional)Stride of the bias reshaped matrix in Z dimension (in bytes)
1433 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1434 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1435 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001436 */
1437__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1438 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001439#if defined(BETA)
1440 IMAGE_DECLARATION(bias),
1441#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001442 IMAGE_DECLARATION(dst),
1443 uint lhs_stride_z,
1444 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001445#if defined(BETA)
1446 uint bias_stride_z,
1447#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001448 uint dst_stride_z
1449#if defined(REINTERPRET_INPUT_AS_3D)
1450 ,
1451 uint lhs_cross_plane_pad
1452#endif // REINTERPRET_INPUT_AS_3D
1453#if defined(REINTERPRET_OUTPUT_AS_3D)
1454 ,
1455 uint dst_cross_plane_pad
1456#endif // REINTERPRET_OUTPUT_AS_3D
1457 )
1458{
1459 // Block size
1460#define RHS_BLOCK_SIZE ((K0) * (N0))
1461
1462 // RHS offset and step X
1463#if defined(RHS_INTERLEAVE)
1464#define RHS_OFFSET_X (N0)
1465#define RHS_STEP_X ((N0) * (H0))
1466#define RHS_STEP_LOOP (1)
1467#else // defined(RHS_INTERLEAVE)
1468#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1469#define RHS_STEP_X (N0)
1470#define RHS_STEP_LOOP (H0)
1471#endif // defined(RHS_INTERLEAVE)
1472
1473 uint x = get_global_id(0);
1474 uint y = get_global_id(1);
1475 uint z = get_global_id(2);
1476
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001477#if defined(DUMMY_WORK_ITEMS)
1478 if((x * N0 >= N) || (y * M0 >= M))
1479 {
1480 return;
1481 }
1482#endif // defined(DUMMY_WORK_ITEMS)
1483
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001484 // Compute LHS matrix address
1485 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1486
1487 // Compute RHS matrix address
1488 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1489
1490#if defined(MATRIX_B_DEPTH)
1491 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1492 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1493#else // defined(MATRIX_B_DEPTH)
1494 rhs_offset += z * rhs_stride_z;
1495#endif // defined(MATRIX_B_DEPTH)
1496
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001497 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1498 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001499
1500#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001501
1502 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001503 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001504
1505 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1506 // multiply lhs_stride_z by DEPTH_GEMM3D
1507 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1508
1509#else // defined(REINTERPRET_INPUT_AS_3D)
1510
1511 // Add offset for batched GEMM
1512 lhs_offset += z * lhs_stride_z;
1513
1514#endif // defined(REINTERPRET_INPUT_AS_3D)
1515
1516 // Initialize the accumulators
1517 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1518
1519 int i = 0;
1520 for(; i <= (K - K0); i += K0)
1521 {
1522 // Supported cases (M0, K0):
1523 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1524 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1525 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1526 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1527 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1528 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1529 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1530 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1531 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001532 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001533
1534 LD_RHS_VFMA_M0xN0(0, a, c);
1535 LD_RHS_VFMA_M0xN0(1, a, c);
1536#if K0 > 2
1537 LD_RHS_VFMA_M0xN0(2, a, c);
1538#endif // K0 > 2
1539#if K0 > 3
1540 LD_RHS_VFMA_M0xN0(3, a, c);
1541#endif // K0 > 3
1542#if K0 > 4
1543 LD_RHS_VFMA_M0xN0(4, a, c);
1544 LD_RHS_VFMA_M0xN0(5, a, c);
1545 LD_RHS_VFMA_M0xN0(6, a, c);
1546 LD_RHS_VFMA_M0xN0(7, a, c);
1547#endif // K0 > 4
1548#if K0 > 8
1549 LD_RHS_VFMA_M0xN0(8, a, c);
1550 LD_RHS_VFMA_M0xN0(9, a, c);
1551 LD_RHS_VFMA_M0xN0(A, a, c);
1552 LD_RHS_VFMA_M0xN0(B, a, c);
1553 LD_RHS_VFMA_M0xN0(C, a, c);
1554 LD_RHS_VFMA_M0xN0(D, a, c);
1555 LD_RHS_VFMA_M0xN0(E, a, c);
1556 LD_RHS_VFMA_M0xN0(F, a, c);
1557#endif // K0 > 8
1558
1559 lhs_offset += K0 * sizeof(DATA_TYPE);
1560 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1561 }
1562
1563 // Left-over accumulations
1564 for(; i < K; ++i)
1565 {
1566 // Load values from LHS matrix
1567 VEC_DATA_TYPE(DATA_TYPE, 2)
1568 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1569#if M0 > 1
1570 VEC_DATA_TYPE(DATA_TYPE, 2)
1571 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1572#endif // M0 > 1
1573#if M0 > 2
1574 VEC_DATA_TYPE(DATA_TYPE, 2)
1575 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1576#endif // M0 > 2
1577#if M0 > 3
1578 VEC_DATA_TYPE(DATA_TYPE, 2)
1579 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1580#endif // M0 > 3
1581#if M0 > 4
1582 VEC_DATA_TYPE(DATA_TYPE, 2)
1583 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1584#endif // M0 > 4
1585#if M0 > 5
1586 VEC_DATA_TYPE(DATA_TYPE, 2)
1587 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1588#endif // M0 > 5
1589#if M0 > 6
1590 VEC_DATA_TYPE(DATA_TYPE, 2)
1591 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1592#endif // M0 > 6
1593#if M0 > 7
1594 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001595 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001596#endif // M0 > 7
1597
1598 LD_RHS_VFMA_M0xN0(0, a, c);
1599
1600 lhs_offset += sizeof(DATA_TYPE);
1601 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1602 }
1603
1604 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1605
1606 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1607
1608#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001609 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001610 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001611
1612 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1613 // multiply dst_stride_z by DEPTH_GEMM3D
1614 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1615
1616#else // defined(REINTERPRET_OUTPUT_AS_3D)
1617
1618 // Add offset for batched GEMM
1619 dst_addr += z * dst_stride_z;
1620
1621#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1622
1623 // Multiply by the weight of matrix-matrix product and store the result
1624#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001625 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001626#endif // defined(ALPHA)
1627
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001628 // Add beta*bias
1629#if defined(BETA)
1630#if defined(BROADCAST_BIAS)
1631 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1632
1633 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1634
1635#ifndef UNIT_BETA
1636 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1637#endif // UNIT_BIAS
1638
1639 // c = c + bias[broadcasted]
1640 ADD_BLOCK_BROADCAST(M0, c, bias0);
1641
1642#else // defined(BROADCAST_BIAS)
1643 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1644 2) * bias_stride_z;
1645
1646 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1647
1648#ifndef UNIT_BETA
1649 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1650#endif // UNIT_BIAS
1651
1652 // c = c + bias
1653 ADD_BLOCK(M0, c, bias);
1654
1655#endif // defined(BROADCAST_BIAS)
1656#endif // defined(BETA)
1657
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001658 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001659 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001660
1661#undef RHS_BLOCK_SIZE
1662#undef RHS_OFFSET_X
1663#undef RHS_STEP_X
1664}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001665#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001666
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001667#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001668
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001669#if K0 == 2
1670#define ARM_DOT_K0(a, b, c) \
1671 ({ \
1672 c = fma(a.s0, b.s0, c); \
1673 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001674 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001675#elif K0 == 3 // K0 == 3
1676#define ARM_DOT_K0(a, b, c) \
1677 ({ \
1678 c = fma(a.s0, b.s0, c); \
1679 c = fma(a.s1, b.s1, c); \
1680 c = fma(a.s2, b.s2, c); \
1681 })
1682#elif K0 == 4 // K0 == 4
1683#define ARM_DOT_K0(a, b, c) \
1684 ({ \
1685 c = fma(a.s0, b.s0, c); \
1686 c = fma(a.s1, b.s1, c); \
1687 c = fma(a.s2, b.s2, c); \
1688 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001689 })
1690#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001691#define ARM_DOT_K0(a, b, c) \
1692 ({ \
1693 c = fma(a.s0, b.s0, c); \
1694 c = fma(a.s1, b.s1, c); \
1695 c = fma(a.s2, b.s2, c); \
1696 c = fma(a.s3, b.s3, c); \
1697 c = fma(a.s4, b.s4, c); \
1698 c = fma(a.s5, b.s5, c); \
1699 c = fma(a.s6, b.s6, c); \
1700 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001701 })
1702#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001703#define ARM_DOT_K0(a, b, c) \
1704 ({ \
1705 c = fma(a.s0, b.s0, c); \
1706 c = fma(a.s1, b.s1, c); \
1707 c = fma(a.s2, b.s2, c); \
1708 c = fma(a.s3, b.s3, c); \
1709 c = fma(a.s4, b.s4, c); \
1710 c = fma(a.s5, b.s5, c); \
1711 c = fma(a.s6, b.s6, c); \
1712 c = fma(a.s7, b.s7, c); \
1713 c = fma(a.s8, b.s8, c); \
1714 c = fma(a.s9, b.s9, c); \
1715 c = fma(a.sA, b.sA, c); \
1716 c = fma(a.sB, b.sB, c); \
1717 c = fma(a.sC, b.sC, c); \
1718 c = fma(a.sD, b.sD, c); \
1719 c = fma(a.sE, b.sE, c); \
1720 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001721 })
1722#else // K0 not supported
1723#error "K0 value not supported"
1724#endif // K0 conditions
1725
1726#if N0 == 2
1727#define ARM_DOT_K0XN0(a, b, c) \
1728 ({ \
1729 ARM_DOT_K0((a), (b##0), (c.s0)); \
1730 ARM_DOT_K0((a), (b##1), (c.s1)); \
1731 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001732#elif N0 == 3 // N0 == 3
1733#define ARM_DOT_K0XN0(a, b, c) \
1734 ({ \
1735 ARM_DOT_K0((a), (b##0), (c.s0)); \
1736 ARM_DOT_K0((a), (b##1), (c.s1)); \
1737 ARM_DOT_K0((a), (b##2), (c.s2)); \
1738 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001739#elif N0 == 4 // N0 == 4
1740#define ARM_DOT_K0XN0(a, b, c) \
1741 ({ \
1742 ARM_DOT_K0((a), (b##0), (c.s0)); \
1743 ARM_DOT_K0((a), (b##1), (c.s1)); \
1744 ARM_DOT_K0((a), (b##2), (c.s2)); \
1745 ARM_DOT_K0((a), (b##3), (c.s3)); \
1746 })
1747#elif N0 == 8 // N0 == 8
1748#define ARM_DOT_K0XN0(a, b, c) \
1749 ({ \
1750 ARM_DOT_K0((a), (b##0), (c.s0)); \
1751 ARM_DOT_K0((a), (b##1), (c.s1)); \
1752 ARM_DOT_K0((a), (b##2), (c.s2)); \
1753 ARM_DOT_K0((a), (b##3), (c.s3)); \
1754 ARM_DOT_K0((a), (b##4), (c.s4)); \
1755 ARM_DOT_K0((a), (b##5), (c.s5)); \
1756 ARM_DOT_K0((a), (b##6), (c.s6)); \
1757 ARM_DOT_K0((a), (b##7), (c.s7)); \
1758 })
1759#elif N0 == 16 // N0 == 16
1760#define ARM_DOT_K0XN0(a, b, c) \
1761 ({ \
1762 ARM_DOT_K0((a), (b##0), (c.s0)); \
1763 ARM_DOT_K0((a), (b##1), (c.s1)); \
1764 ARM_DOT_K0((a), (b##2), (c.s2)); \
1765 ARM_DOT_K0((a), (b##3), (c.s3)); \
1766 ARM_DOT_K0((a), (b##4), (c.s4)); \
1767 ARM_DOT_K0((a), (b##5), (c.s5)); \
1768 ARM_DOT_K0((a), (b##6), (c.s6)); \
1769 ARM_DOT_K0((a), (b##7), (c.s7)); \
1770 ARM_DOT_K0((a), (b##8), (c.s8)); \
1771 ARM_DOT_K0((a), (b##9), (c.s9)); \
1772 ARM_DOT_K0((a), (b##A), (c.sA)); \
1773 ARM_DOT_K0((a), (b##B), (c.sB)); \
1774 ARM_DOT_K0((a), (b##C), (c.sC)); \
1775 ARM_DOT_K0((a), (b##D), (c.sD)); \
1776 ARM_DOT_K0((a), (b##E), (c.sE)); \
1777 ARM_DOT_K0((a), (b##F), (c.sF)); \
1778 })
1779#else // N0 not supported
1780#error "N0 value not supported"
1781#endif // N0 conditions
1782
1783/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1784 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 blocks must NOT be transposed
1785 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 blocks must be transposed
1786 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001787 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1788 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001789 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1790 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1791 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1792 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
1793 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1794 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001795 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001796 * - N0 = 2, 3, 4, 8, 16
1797 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001798 * - V0 >= 1
1799 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001800 *
1801 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1802 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1803 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1804 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1805 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1806 *
1807 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1808 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1809 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem (in bytes)
1810 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1811 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem (in bytes)
1812 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001813 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001814 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1815 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem (in bytes)
1816 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1817 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem (in bytes)
1818 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001819 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001820 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1821 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1822 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1823 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1824 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001825 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001826 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1827 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1828 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1829 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1830 */
1831__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1832 IMAGE_DECLARATION(rhs),
1833 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001834 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001835 uint lhs_stride_z,
1836 uint rhs_stride_z,
1837 uint dst_stride_z
1838#if defined(REINTERPRET_OUTPUT_AS_3D)
1839 ,
1840 uint dst_cross_plane_pad
1841#endif // REINTERPRET_OUTPUT_AS_3D
1842 )
1843{
1844 // Block size
1845#define LHS_BLOCK_SIZE ((K0) * (M0))
1846
1847#if defined(LHS_INTERLEAVE)
1848#define LHS_OFFSET_X (K0)
1849#define LHS_STEP_X ((K0) * (V0))
1850#define LHS_STEP_LOOP (1)
1851#else // defined(INTERLEAVE)
1852#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1853#define LHS_STEP_X (K0)
1854#define LHS_STEP_LOOP (V0)
1855#endif // defined(INTERLEAVE)
1856
1857 // Block size
1858#define RHS_BLOCK_SIZE ((K0) * (N0))
1859
1860 // RHS offset and step X
1861#if defined(RHS_INTERLEAVE)
1862#define RHS_OFFSET_X (K0)
1863#define RHS_STEP_X ((K0) * (H0))
1864#define RHS_STEP_LOOP (1)
1865#else // defined(RHS_INTERLEAVE)
1866#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1867#define RHS_STEP_X (K0)
1868#define RHS_STEP_LOOP (H0)
1869#endif // defined(RHS_INTERLEAVE)
1870
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001871#if defined(DUMMY_WORK_ITEMS)
1872 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1873 {
1874 return;
1875 }
1876#endif // defined(DUMMY_WORK_ITEMS)
1877
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001878 // Compute LHS matrix address
1879 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1880 (get_global_id(2) * lhs_stride_z);
1881
1882 // Compute RHS matrix address
1883 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1884
1885#if defined(MATRIX_B_DEPTH)
1886 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1887 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1888#else // defined(MATRIX_B_DEPTH)
1889 rhs_addr += get_global_id(2) * rhs_stride_z;
1890#endif // defined(MATRIX_B_DEPTH)
1891
1892 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001893 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001894
Usama Arif0681e3b2019-04-25 14:28:07 +01001895 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1896 REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
1897
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001898 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001899 {
1900 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001901 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1902 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1903 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1904 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1905 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1906 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1907 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1908 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001909 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001910 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001911
1912 // Load values from RHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001913 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zrhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001914
1915 // Accumulate
1916 ARM_DOT_K0XN0(a0, b, c0);
1917#if M0 > 1
1918 ARM_DOT_K0XN0(a1, b, c1);
1919#endif // M0 > 1
1920#if M0 > 2
1921 ARM_DOT_K0XN0(a2, b, c2);
1922#endif // M0 > 2
1923#if M0 > 3
1924 ARM_DOT_K0XN0(a3, b, c3);
1925#endif // M0 > 3
1926#if M0 > 4
1927 ARM_DOT_K0XN0(a4, b, c4);
1928#endif // M0 > 4
1929#if M0 > 5
1930 ARM_DOT_K0XN0(a5, b, c5);
1931#endif // M0 > 5
1932#if M0 > 6
1933 ARM_DOT_K0XN0(a6, b, c6);
1934#endif // M0 > 6
1935#if M0 > 7
1936 ARM_DOT_K0XN0(a7, b, c7);
1937#endif // M0 > 7
1938
1939 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1940 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1941 }
1942
1943 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1944
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001945 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001946
1947#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001948
1949 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001950 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001951 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1952 // multiply dst_stride_z by DEPTH_GEMM3D
1953 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1954
1955#else // defined(REINTERPRET_OUTPUT_AS_3D)
1956
1957 // Add offset for batched GEMM
1958 dst_addr += get_global_id(2) * dst_stride_z;
1959
1960#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1961
1962 // Multiply by the weight of matrix-matrix product and store the result
1963#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001964 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001965#endif // defined(ALPHA)
1966
1967 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001968 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001969#undef LHS_BLOCK_SIZE
1970#undef LHS_OFFSET_X
1971#undef LHS_STEP_X
1972#undef RHS_BLOCK_SIZE
1973#undef RHS_OFFSET_X
1974#undef RHS_STEP_X
1975}
giuros01b3204e72019-04-01 13:50:22 +01001976
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001977#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
1978
giuros01b3204e72019-04-01 13:50:22 +01001979#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
1980
// VFMA(a, b, c): fused multiply-add accumulation, c = fma(a, b, c) (i.e. c = a * b + c).
// Written as a statement expression, so it can be used as a statement followed by ';'.
// c must be an lvalue; it is both read and overwritten.
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
1985
// RHS_VFMA_M0xN0(i, a, b, c): for every row m in [0, M0), broadcast component i of a<m> to an
// N0-wide vector and accumulate it times b into c<m>, i.e. c<m> = fma(a<m>.s<i>, b, c<m>).
// i is token-pasted after ".s", so it must be a single vector component index token (0-9, A-F).
// Exactly one definition is selected by M0; values outside 1..8 stop compilation.
#if M0 == 8 // M0 == 8
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
    })
#elif M0 == 7 // M0 == 7
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
    })
#elif M0 == 6 // M0 == 6
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
    })
#elif M0 == 5 // M0 == 5
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
    })
#elif M0 == 4 // M0 == 4
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
    })
#elif M0 == 3 // M0 == 3
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
    })
#elif M0 == 2 // M0 == 2
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
    })
#elif M0 == 1 // M0 == 1
#define RHS_VFMA_M0xN0(i, a, b, c)                                    \
    ({                                                                \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported
2057
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS matrix is NOT reshaped
 *
 *  Each work-item computes a block of M0 rows x N0 columns of the destination matrix.
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (i.e. -DM=52, -DN=30 and -DK=90)
 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2)
 * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 * @note The optional alpha value can be passed at compile time using -DALPHA. If defined, the result is multiplied by ALPHA before being stored
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = M, the number of rows of the LHS matrix
 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                      Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                      Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the RHS matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
                             IMAGE_DECLARATION(rhs),
                             IMAGE_DECLARATION(dst),
                             uint lhs_stride_z,
                             uint rhs_stride_z,
                             uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                             ,
                             uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                             ,
                             uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                            )
{
    // x selects a block of N0 output columns, y a block of M0 output rows, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items that fall outside the M x N output have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address: first of the M0 rows handled by this work-item
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address: first of the N0 columns handled by this work-item
    uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Extra per-row byte offsets passed to LOAD_BLOCK; zlhs is only recomputed below
    // when REINTERPRET_INPUT_AS_3D is defined
    REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    // Main loop: consume K0 columns of LHS (and K0 rows of RHS) per iteration
    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix (a0..a(M0-1))
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix (b0..b(K0-1))
        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);

        // For each k in [0, K0): c(m) += a(m).s(k) * b(k)
        RHS_VFMA_M0xN0(0, a, b0, c);
        RHS_VFMA_M0xN0(1, a, b1, c);
#if K0 > 2
        RHS_VFMA_M0xN0(2, a, b2, c);
#endif // K0 > 2
#if K0 > 3
        RHS_VFMA_M0xN0(3, a, b3, c);
#endif // K0 > 3
#if K0 > 4
        RHS_VFMA_M0xN0(4, a, b4, c);
        RHS_VFMA_M0xN0(5, a, b5, c);
        RHS_VFMA_M0xN0(6, a, b6, c);
        RHS_VFMA_M0xN0(7, a, b7, c);
#endif // K0 > 4
#if K0 > 8
        RHS_VFMA_M0xN0(8, a, b8, c);
        RHS_VFMA_M0xN0(9, a, b9, c);
        RHS_VFMA_M0xN0(A, a, b10, c);
        RHS_VFMA_M0xN0(B, a, b11, c);
        RHS_VFMA_M0xN0(C, a, b12, c);
        RHS_VFMA_M0xN0(D, a, b13, c);
        RHS_VFMA_M0xN0(E, a, b14, c);
        RHS_VFMA_M0xN0(F, a, b15, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * rhs_stride_y;
    }

    // Left-over accumulations: the remaining K % K0 columns, one at a time
    for(; i < K; ++i)
    {
        // Load values from LHS matrix. Each scalar is broadcast into a 2-element vector so that
        // RHS_VFMA_M0xN0 can read it through the .s0 component
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
#endif // M0 > 7

        // Load one row of N0 values from the RHS matrix and accumulate
        VEC_DATA_TYPE(DATA_TYPE, N0)
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
        RHS_VFMA_M0xN0(0, a, b, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += rhs_stride_y;
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
}
2293#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2294
Gian Marco36a0a462018-01-12 10:21:40 +00002295#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C.
 *       The vector is added to every row of the result after the optional ALPHA scaling.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's 4-element group within the interleaved (A) / transposed (B) rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src0_addr_in_bytes = byte offset of the row of matrix A
    // src1_addr_in_bytes = byte offset of the row of matrix B
    // The offsets are kept unsigned (as the uint z-strides are) so that values above INT_MAX
    // are not truncated by a conversion to a signed int
    uint src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    uint src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two 4-element steps of A and B per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over: one 4-element step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 4 elements of the vector src2 to each of the 4 rows of the block
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float4          c0        = vload4(0, src2_addr);

    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Both operands are uint4 so the vector min() overload applies unambiguously
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
2493
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002494/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002495 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
2496 *
2497 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002498 *
Gian Marco19835e52018-01-30 13:35:54 +00002499 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2500 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2501 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002502 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2503 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2504 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002505 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002506 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2507 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2508 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2509 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2510 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2511 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002512 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2513 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002514 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2515 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2516 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2517 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2518 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2519 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002520 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002521 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2522 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2523 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2524 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2525 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002526 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2527 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2528 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2529 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002530 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002531 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002532 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002533 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002534 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002535 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002536 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2537 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2538 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002539 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002540 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002541__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
2542 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002543#if defined(ADD_VEC_C)
2544 VECTOR_DECLARATION(src2),
2545#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002546 IMAGE_DECLARATION(dst),
2547 uint src0_stride_z,
2548 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002549 uint dst_stride_z
2550#if defined(REINTERPRET_OUTPUT_AS_3D)
2551 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002552 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002553#endif // REINTERPRET_OUTPUT_AS_3D
2554 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002555{
Gian Marco36a0a462018-01-12 10:21:40 +00002556 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2557 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002558 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00002559
2560 // Offset
2561 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2562 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
2563
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002564 // src_addr_a = address of matrix A
2565 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002566 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2567 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2568
2569#if defined(MATRIX_B_DEPTH)
2570 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2571 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2572#else // defined(MATRIX_B_DEPTH)
2573 src1_addr_in_bytes += z * src1_stride_z;
2574#endif // defined(MATRIX_B_DEPTH)
2575
2576 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2577 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002578
Gian Marco36a0a462018-01-12 10:21:40 +00002579 src_addr_a += offset_row_a;
2580 src_addr_b += offset_row_b;
2581
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002582 // Reset accumulators
2583 float c00 = 0.0f;
2584 float c01 = 0.0f;
2585 float c02 = 0.0f;
2586 float c03 = 0.0f;
2587 float c10 = 0.0f;
2588 float c11 = 0.0f;
2589 float c12 = 0.0f;
2590 float c13 = 0.0f;
2591 float c20 = 0.0f;
2592 float c21 = 0.0f;
2593 float c22 = 0.0f;
2594 float c23 = 0.0f;
2595 float c30 = 0.0f;
2596 float c31 = 0.0f;
2597 float c32 = 0.0f;
2598 float c33 = 0.0f;
2599
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002600#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
2601
2602 int i = 0;
2603 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002604 {
2605 // Load values from matrix A (interleaved) and matrix B (transposed)
2606 float4 a0 = vload4(0, src_addr_a);
2607 float4 b0 = vload4(0, src_addr_b);
2608
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002609 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2610 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002611
2612 c00 = fma(a0.s0, b0.s0, c00);
2613 c01 = fma(a0.s0, b0.s1, c01);
2614 c02 = fma(a0.s0, b0.s2, c02);
2615 c03 = fma(a0.s0, b0.s3, c03);
2616
2617 c10 = fma(a0.s1, b0.s0, c10);
2618 c11 = fma(a0.s1, b0.s1, c11);
2619 c12 = fma(a0.s1, b0.s2, c12);
2620 c13 = fma(a0.s1, b0.s3, c13);
2621
2622 c20 = fma(a0.s2, b0.s0, c20);
2623 c21 = fma(a0.s2, b0.s1, c21);
2624 c22 = fma(a0.s2, b0.s2, c22);
2625 c23 = fma(a0.s2, b0.s3, c23);
2626
2627 c30 = fma(a0.s3, b0.s0, c30);
2628 c31 = fma(a0.s3, b0.s1, c31);
2629 c32 = fma(a0.s3, b0.s2, c32);
2630 c33 = fma(a0.s3, b0.s3, c33);
2631
2632 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002633 a0 = vload4(0, src_addr_a);
2634 b0 = vload4(0, src_addr_b);
2635
2636 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2637 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002638
2639 c00 = fma(a0.s0, b0.s0, c00);
2640 c01 = fma(a0.s0, b0.s1, c01);
2641 c02 = fma(a0.s0, b0.s2, c02);
2642 c03 = fma(a0.s0, b0.s3, c03);
2643
2644 c10 = fma(a0.s1, b0.s0, c10);
2645 c11 = fma(a0.s1, b0.s1, c11);
2646 c12 = fma(a0.s1, b0.s2, c12);
2647 c13 = fma(a0.s1, b0.s3, c13);
2648
2649 c20 = fma(a0.s2, b0.s0, c20);
2650 c21 = fma(a0.s2, b0.s1, c21);
2651 c22 = fma(a0.s2, b0.s2, c22);
2652 c23 = fma(a0.s2, b0.s3, c23);
2653
2654 c30 = fma(a0.s3, b0.s0, c30);
2655 c31 = fma(a0.s3, b0.s1, c31);
2656 c32 = fma(a0.s3, b0.s2, c32);
2657 c33 = fma(a0.s3, b0.s3, c33);
2658
2659 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002660 a0 = vload4(0, src_addr_a);
2661 b0 = vload4(0, src_addr_b);
2662
2663 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2664 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2665
2666 c00 = fma(a0.s0, b0.s0, c00);
2667 c01 = fma(a0.s0, b0.s1, c01);
2668 c02 = fma(a0.s0, b0.s2, c02);
2669 c03 = fma(a0.s0, b0.s3, c03);
2670
2671 c10 = fma(a0.s1, b0.s0, c10);
2672 c11 = fma(a0.s1, b0.s1, c11);
2673 c12 = fma(a0.s1, b0.s2, c12);
2674 c13 = fma(a0.s1, b0.s3, c13);
2675
2676 c20 = fma(a0.s2, b0.s0, c20);
2677 c21 = fma(a0.s2, b0.s1, c21);
2678 c22 = fma(a0.s2, b0.s2, c22);
2679 c23 = fma(a0.s2, b0.s3, c23);
2680
2681 c30 = fma(a0.s3, b0.s0, c30);
2682 c31 = fma(a0.s3, b0.s1, c31);
2683 c32 = fma(a0.s3, b0.s2, c32);
2684 c33 = fma(a0.s3, b0.s3, c33);
2685
2686 // Load values from matrix A (interleaved) and matrix B (transposed)
2687 a0 = vload4(0, src_addr_a);
2688 b0 = vload4(0, src_addr_b);
2689
2690 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2691 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002692
2693 c00 = fma(a0.s0, b0.s0, c00);
2694 c01 = fma(a0.s0, b0.s1, c01);
2695 c02 = fma(a0.s0, b0.s2, c02);
2696 c03 = fma(a0.s0, b0.s3, c03);
2697
2698 c10 = fma(a0.s1, b0.s0, c10);
2699 c11 = fma(a0.s1, b0.s1, c11);
2700 c12 = fma(a0.s1, b0.s2, c12);
2701 c13 = fma(a0.s1, b0.s3, c13);
2702
2703 c20 = fma(a0.s2, b0.s0, c20);
2704 c21 = fma(a0.s2, b0.s1, c21);
2705 c22 = fma(a0.s2, b0.s2, c22);
2706 c23 = fma(a0.s2, b0.s3, c23);
2707
2708 c30 = fma(a0.s3, b0.s0, c30);
2709 c31 = fma(a0.s3, b0.s1, c31);
2710 c32 = fma(a0.s3, b0.s2, c32);
2711 c33 = fma(a0.s3, b0.s3, c33);
2712 }
2713
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002714 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002715 {
2716 // Load values from matrix A (interleaved) and matrix B (transposed)
2717 float4 a0 = vload4(0, src_addr_a);
2718 float4 b0 = vload4(0, src_addr_b);
2719
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002720 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
2721 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
2722
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002723 c00 = fma(a0.s0, b0.s0, c00);
2724 c01 = fma(a0.s0, b0.s1, c01);
2725 c02 = fma(a0.s0, b0.s2, c02);
2726 c03 = fma(a0.s0, b0.s3, c03);
2727
2728 c10 = fma(a0.s1, b0.s0, c10);
2729 c11 = fma(a0.s1, b0.s1, c11);
2730 c12 = fma(a0.s1, b0.s2, c12);
2731 c13 = fma(a0.s1, b0.s3, c13);
2732
2733 c20 = fma(a0.s2, b0.s0, c20);
2734 c21 = fma(a0.s2, b0.s1, c21);
2735 c22 = fma(a0.s2, b0.s2, c22);
2736 c23 = fma(a0.s2, b0.s3, c23);
2737
2738 c30 = fma(a0.s3, b0.s0, c30);
2739 c31 = fma(a0.s3, b0.s1, c31);
2740 c32 = fma(a0.s3, b0.s2, c32);
2741 c33 = fma(a0.s3, b0.s3, c33);
2742 }
2743
2744 // Compute destination address
2745 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2746
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002747#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002748 // Multiply by the weight of matrix product
2749 c00 = c00 * ALPHA;
2750 c01 = c01 * ALPHA;
2751 c02 = c02 * ALPHA;
2752 c03 = c03 * ALPHA;
2753 c10 = c10 * ALPHA;
2754 c11 = c11 * ALPHA;
2755 c12 = c12 * ALPHA;
2756 c13 = c13 * ALPHA;
2757 c20 = c20 * ALPHA;
2758 c21 = c21 * ALPHA;
2759 c22 = c22 * ALPHA;
2760 c23 = c23 * ALPHA;
2761 c30 = c30 * ALPHA;
2762 c31 = c31 * ALPHA;
2763 c32 = c32 * ALPHA;
2764 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002765#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002766
Gian Marcoae2af742018-02-15 12:35:44 +00002767 // Compute dst address
2768 __global uchar *dst_addr = offset(&dst, 0, 0);
2769
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002770#if defined(ADD_VEC_C)
2771 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2772 float4 c0 = vload4(0, src2_addr);
2773
2774 c00 += c0.s0;
2775 c01 += c0.s1;
2776 c02 += c0.s2;
2777 c03 += c0.s3;
2778 c10 += c0.s0;
2779 c11 += c0.s1;
2780 c12 += c0.s2;
2781 c13 += c0.s3;
2782 c20 += c0.s0;
2783 c21 += c0.s1;
2784 c22 += c0.s2;
2785 c23 += c0.s3;
2786 c30 += c0.s0;
2787 c31 += c0.s1;
2788 c32 += c0.s2;
2789 c33 += c0.s3;
2790#endif /* defined(ADD_VEC_C) */
2791
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002792#if defined(REINTERPRET_OUTPUT_AS_3D)
2793 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002794 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002795 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002796 // | |
2797 // | plane0 |
2798 // | |
2799 // |__________________|
2800 // |******************|
2801 // | cross_plane_pad |
2802 // |******************|
2803 // | |
2804 // | plane1 |
2805 // | |
2806 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002807
2808 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2809 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2810 zout = min(DEPTH_GEMM3D - 1, zout);
2811
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002812 // Add offset due to the cross plane paddings
2813 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002814
2815 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2816 // multiply dst_stride_z by DEPTH_GEMM3D
2817 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2818
2819 // Store 4x4 block
2820 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2821 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2822 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2823 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2824
2825#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002826 // Add offset for batched GEMM
2827 dst_addr += z * dst_stride_z;
2828
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002829 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002830 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2831 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2832 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2833 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002834#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002835}
2836
Georgios Pinitas84225582018-05-14 12:00:05 +01002837// Undefine local defines
2838#undef COLS_MTX_B
2839
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002840#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002841/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002842 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002843 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002844 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2845 *
Gian Marco19835e52018-01-30 13:35:54 +00002846 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2847 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2848 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002849 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2850 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002851 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002852 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2853 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2854 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2855 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2856 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2857 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002858 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2859 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002860 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2861 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2862 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2863 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2864 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2865 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002866 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002867 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2868 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2869 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2870 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2871 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002872 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2873 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2874 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2875 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002876 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002877 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002878 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002879 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002880 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002881 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002882 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2883 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2884 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002885 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002886 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002887__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
2888 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002889#if defined(ADD_VEC_C)
2890 VECTOR_DECLARATION(src2),
2891#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002892 IMAGE_DECLARATION(dst),
2893 uint src0_stride_z,
2894 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002895 uint dst_stride_z
2896#if defined(REINTERPRET_OUTPUT_AS_3D)
2897 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002898 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002899#endif // REINTERPRET_OUTPUT_AS_3D
2900 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002901{
Gian Marco36a0a462018-01-12 10:21:40 +00002902 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2903 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002904 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002905
Gian Marco36a0a462018-01-12 10:21:40 +00002906 // Offset
2907 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2908 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002909
Gian Marco36a0a462018-01-12 10:21:40 +00002910 // src_addr_a = address of matrix A
2911 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002912 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2913 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2914
2915#if defined(MATRIX_B_DEPTH)
2916 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2917 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2918#else // defined(MATRIX_B_DEPTH)
2919 src1_addr_in_bytes += z * src1_stride_z;
2920#endif // defined(MATRIX_B_DEPTH)
2921
2922 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
2923 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002924
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002925 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002926 __global half *src_end_addr_b = src_addr_b + COLS_B;
2927
2928 src_addr_a += offset_row_a;
2929 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002930
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002931 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002932 half8 c00 = 0.0f;
2933 half8 c10 = 0.0f;
2934 half8 c20 = 0.0f;
2935 half8 c30 = 0.0f;
2936
Gian Marco36a0a462018-01-12 10:21:40 +00002937 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002938 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002939 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002940 half4 a0 = vload4(0, src_addr_a);
2941 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002942
2943 c00 += (half8)a0.s0 * b0;
2944 c10 += (half8)a0.s1 * b0;
2945 c20 += (half8)a0.s2 * b0;
2946 c30 += (half8)a0.s3 * b0;
2947
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002948 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002949 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2950 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002951
2952 c00 += (half8)a0.s0 * b0;
2953 c10 += (half8)a0.s1 * b0;
2954 c20 += (half8)a0.s2 * b0;
2955 c30 += (half8)a0.s3 * b0;
2956 }
2957
Gian Marco36a0a462018-01-12 10:21:40 +00002958 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002959 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002960 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002961 half4 a0 = vload4(0, src_addr_a);
2962 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002963
2964 c00 += (half8)a0.s0 * b0;
2965 c10 += (half8)a0.s1 * b0;
2966 c20 += (half8)a0.s2 * b0;
2967 c30 += (half8)a0.s3 * b0;
2968 }
2969
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002970 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002971 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2972
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002973#if defined(ALPHA)
2974 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002975 c00 = c00 * (half8)ALPHA;
2976 c10 = c10 * (half8)ALPHA;
2977 c20 = c20 * (half8)ALPHA;
2978 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002979#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002980
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002981#if defined(ADD_VEC_C)
2982 // *INDENT-OFF*
2983 // clang-format off
2984 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2985 half8 c0 = vload8(0, src2_addr);
2986 // clang-format on
2987 // *INDENT-ON*
2988
2989 c00 += c0;
2990 c10 += c0;
2991 c20 += c0;
2992 c30 += c0;
2993#endif /* defined(ADD_VEC_C) */
2994
Gian Marcoae2af742018-02-15 12:35:44 +00002995 // Compute dst address
2996 __global uchar *dst_addr = offset(&dst, 0, 0);
2997
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002998#if defined(REINTERPRET_OUTPUT_AS_3D)
2999 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003000 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003001 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003002 // | |
3003 // | plane0 |
3004 // | |
3005 // |__________________|
3006 // |******************|
3007 // | cross_plane_pad |
3008 // |******************|
3009 // | |
3010 // | plane1 |
3011 // | |
3012 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003013
3014 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
3015 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3016 zout = min(DEPTH_GEMM3D - 1, zout);
3017
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003018 // Add offset due to the cross plane paddings
3019 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003020
3021 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3022 // multiply dst_stride_z by DEPTH_GEMM3D
3023 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3024
3025 // Store 4x8 block
3026 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3027 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3028 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3029 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
3030
3031#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003032 // Add offset for batched GEMM
3033 dst_addr += z * dst_stride_z;
3034
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003035 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +00003036 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
3037 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
3038 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
3039 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003040#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003041}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003042
/** This OpenCL kernel multiplies matrix A (src0) by matrix B (src1) for F16 data, keeping the partial sums in 32-bit floating point (float) accumulators.
 * Before running this kernel, matrix A must be reshaped with @ref gemm_interleave4x4_16bit and matrix B with @ref gemm_transpose1x8.
 * Each work-item computes a 4x8 tile of the output; the float results are converted back to F16 only when they are stored.
 *
 * Optionally, a vector (src2) is added to every row of the output tile when ADD_VEC_C is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note When matrix B has 3 dimensions and matrix A has more than 3, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       to avoid out-of-bounds reads. This happens when GEMM performs an element-wise multiplication through a batched matrix multiplication (2D Winograd)
 *       with multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note When the output has to be reinterpreted as a 3D tensor (e.g. output of a convolution layer), the following must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) should cover the rows (M) of the NOT reshaped matrix A: output row m is placed in plane m / HEIGHT_GEMM3D,
 *          clamped to DEPTH_GEMM3D - 1
 *
 * @note When a 3rd input (src2) has to be added, ADD_VEC_C must be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // MULT_TRANSPOSE1XW_WIDTH (resp. MULT_INTERLEAVE4X4_HEIGHT) consecutive work-items share one reshaped row of B (resp. A)
    const int b_row = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offset of this work-item's slice inside the shared reshaped rows
    const int a_lane = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_lane = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the beginning of the reshaped rows of A and B
    int a_offset_bytes = batch * src0_stride_z + a_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = b_row * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer dimensions than matrix A: wrap the batch index instead of sliding B
    b_offset_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_bytes);

    // The K loops stop once b_ptr reaches COLS_B elements past the start of the B row
    __global half *const b_end = b_ptr + COLS_B;

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_bytes) + a_lane;
    b_ptr += b_lane;

    // 32-bit accumulators for the 4x8 output tile
    float8 acc0 = (float8)0.0f;
    float8 acc1 = (float8)0.0f;
    float8 acc2 = (float8)0.0f;
    float8 acc3 = (float8)0.0f;

    // Main loop: two K steps per iteration
    while(b_ptr <= (b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        float4 a_vals = convert_float4(vload4(0, a_ptr));
        float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_vals = convert_float4(vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b_vals = convert_float8(vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Remainder loop: one K step per iteration
    while(b_ptr < b_end)
    {
        const float4 a_vals = convert_float4(vload4(0, a_ptr));
        const float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a_vals.s0 * b_vals;
        acc1 += (float8)a_vals.s1 * b_vals;
        acc2 += (float8)a_vals.s2 * b_vals;
        acc3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the product by alpha
    acc0 *= (float8)ALPHA;
    acc1 *= (float8)ALPHA;
    acc2 *= (float8)ALPHA;
    acc3 *= (float8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *bias_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const float8 bias = convert_float8(vload8(0, bias_ptr));
    // clang-format on
    // *INDENT-ON*

    // The same 8-element vector is added to each of the 4 rows of the tile
    acc0 += bias;
    acc1 += bias;
    acc2 += bias;
    acc3 += bias;
#endif /* defined(ADD_VEC_C) */

    // Address of the first row of this work-item's output tile
    Image           dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out_ptr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile is written into a 3D tensor whose planes (z) are separated by cross_plane_pad rows of padding.
    // Each output row m = get_global_id(1) * 4 + r belongs to plane m / HEIGHT_GEMM3D (clamped to the last plane),
    // and every plane crossed adds cross_plane_pad * dst_stride_y bytes to that row's address.
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is multiplied by DEPTH_GEMM3D
    out_ptr += batch * dst_stride_z * DEPTH_GEMM3D;

    vstore8(convert_half8(acc0), 0, (__global half *)(out_ptr + plane_offset.s0));
    vstore8(convert_half8(acc1), 0, (__global half *)(out_ptr + dst_stride_y + plane_offset.s1));
    vstore8(convert_half8(acc2), 0, (__global half *)(out_ptr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(convert_half8(acc3), 0, (__global half *)(out_ptr + 3 * dst_stride_y + plane_offset.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: move to this work-item's batch
    out_ptr += batch * dst_stride_z;

    vstore8(convert_half8(acc0), 0, (__global half *)out_ptr);
    vstore8(convert_half8(acc1), 0, (__global half *)(out_ptr + dst_stride_y));
    vstore8(convert_half8(acc2), 0, (__global half *)(out_ptr + 2 * dst_stride_y));
    vstore8(convert_half8(acc3), 0, (__global half *)(out_ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3244
/** This OpenCL kernel, tuned for Bifrost architectures, multiplies matrix A (src0) by matrix B (src1) for F16 data.
 * Before running this kernel, matrix A must be reshaped with @ref gemm_interleave4x4_16bit and matrix B with @ref gemm_transpose1x8.
 * Each work-item computes a 4x8 tile of the output, accumulating in half precision with fma().
 *
 * Optionally, a vector (src2) is added to every row of the output tile when ADD_VEC_C is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note When matrix B has 3 dimensions and matrix A has more than 3, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       to avoid out-of-bounds reads. This happens when GEMM performs an element-wise multiplication through a batched matrix multiplication (2D Winograd)
 *       with multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note When the output has to be reinterpreted as a 3D tensor (e.g. output of a convolution layer), the following must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) should cover the rows (M) of the NOT reshaped matrix A: output row m is placed in plane m / HEIGHT_GEMM3D,
 *          clamped to DEPTH_GEMM3D - 1
 *
 * @note When a 3rd input (src2) has to be added, ADD_VEC_C must be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // MULT_TRANSPOSE1XW_WIDTH (resp. MULT_INTERLEAVE4X4_HEIGHT) consecutive work-items share one reshaped row of B (resp. A)
    const int b_row = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offset of this work-item's slice inside the shared reshaped rows
    const int a_lane = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_lane = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the beginning of the reshaped rows of A and B
    int a_offset_bytes = batch * src0_stride_z + a_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = b_row * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer dimensions than matrix A: wrap the batch index instead of sliding B
    b_offset_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_bytes) + a_lane;
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_bytes) + b_lane;

    // Half-precision accumulators for the 4x8 output tile
    half8 acc0 = (half8)0.0f;
    half8 acc1 = (half8)0.0f;
    half8 acc2 = (half8)0.0f;
    half8 acc3 = (half8)0.0f;

// Number of K steps: each step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the reshaped B row
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    // Unrolled main loop: four K steps per iteration. Loads are addressed relative to the
    // pointers at the start of the iteration; both pointers are advanced once at the end.
    int k = 0;
    for(; k <= (int)(COLS_MTX_B - 4); k += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // Without an interleave multiplier, two consecutive K steps of A are contiguous: read them as one half8
        half8 a_vals = vload8(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        b_vals = vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s4, b_vals, acc0);
        acc1 = fma((half8)a_vals.s5, b_vals, acc1);
        acc2 = fma((half8)a_vals.s6, b_vals, acc2);
        acc3 = fma((half8)a_vals.s7, b_vals, acc3);

        a_vals = vload8(0, a_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 16 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        b_vals = vload8(0, b_ptr + 24 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s4, b_vals, acc0);
        acc1 = fma((half8)a_vals.s5, b_vals, acc1);
        acc2 = fma((half8)a_vals.s6, b_vals, acc2);
        acc3 = fma((half8)a_vals.s7, b_vals, acc3);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        half4 a_vals = vload4(0, a_ptr);
        half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        a_vals = vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        a_vals = vload4(0, a_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 16 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        a_vals = vload4(0, a_ptr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vals = vload8(0, b_ptr + 24 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1

        // Four K steps consumed: 4 * (4 * MULT_INTERLEAVE4X4_HEIGHT) elements of A, 4 * (8 * MULT_TRANSPOSE1XW_WIDTH) of B
        a_ptr += 16 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 32 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Remainder loop: one K step per iteration
    for(; k < (int)(COLS_MTX_B); ++k)
    {
        const half4 a_vals = vload4(0, a_ptr);
        const half8 b_vals = vload8(0, b_ptr);

        acc0 = fma((half8)a_vals.s0, b_vals, acc0);
        acc1 = fma((half8)a_vals.s1, b_vals, acc1);
        acc2 = fma((half8)a_vals.s2, b_vals, acc2);
        acc3 = fma((half8)a_vals.s3, b_vals, acc3);

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the product by alpha
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *bias_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const half8 bias = vload8(0, bias_ptr);
    // clang-format on
    // *INDENT-ON*

    // The same 8-element vector is added to each of the 4 rows of the tile
    acc0 += bias;
    acc1 += bias;
    acc2 += bias;
    acc3 += bias;
#endif /* defined(ADD_VEC_C) */

    // Address of the first row of this work-item's output tile
    Image           dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out_ptr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile is written into a 3D tensor whose planes (z) are separated by cross_plane_pad rows of padding.
    // Each output row m = get_global_id(1) * 4 + r belongs to plane m / HEIGHT_GEMM3D (clamped to the last plane),
    // and every plane crossed adds cross_plane_pad * dst_stride_y bytes to that row's address.
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is multiplied by DEPTH_GEMM3D
    out_ptr += batch * dst_stride_z * DEPTH_GEMM3D;

    vstore8(acc0, 0, (__global half *)(out_ptr + plane_offset.s0));
    vstore8(acc1, 0, (__global half *)(out_ptr + dst_stride_y + plane_offset.s1));
    vstore8(acc2, 0, (__global half *)(out_ptr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(acc3, 0, (__global half *)(out_ptr + 3 * dst_stride_y + plane_offset.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: move to this work-item's batch
    out_ptr += batch * dst_stride_z;

    vstore8(acc0, 0, (__global half *)out_ptr);
    vstore8(acc1, 0, (__global half *)(out_ptr + dst_stride_y));
    vstore8(acc2, 0, (__global half *)(out_ptr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(out_ptr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitas84225582018-05-14 12:00:05 +01003525
// Undefine local defines
#undef COLS_MTX_B

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// Every guard uses defined(): testing the bare value of NUM_ELEMS_PROCESSED_PER_THREAD_Y would check its
// numeric value rather than whether it was passed at compile time, unlike the sibling conditions.
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
// Vector type holding NUM_ELEMS_PROCESSED_PER_THREAD_X elements of DATA_TYPE
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003536/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
3537 *
3538 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003539 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003540 * @note This OpenCL kernel works with floating point data types (F16/F32)
3541 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
3542 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003543 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003544 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3545 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003546 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003547 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3548 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003549 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3550 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3551 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3552 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3553 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003554 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3555 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003556 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003557 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3558 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3559 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3560 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3561 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003562 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003563 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3564 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3565 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3566 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3567 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003568 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
3569 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3570 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3571 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003572 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003573 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
3577 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003578 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3579 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3580 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003581 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3582 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003583 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003584__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
3585 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003586#if defined(ADD_VEC_C)
3587 VECTOR_DECLARATION(src2),
3588#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003589 IMAGE_DECLARATION(dst),
3590 uint src0_stride_z,
3591 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003592 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003593#if defined(REINTERPRET_INPUT_AS_3D)
3594 ,
3595 uint src_cross_plane_pad
3596#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003597#if defined(REINTERPRET_OUTPUT_AS_3D)
3598 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003599 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003600#endif // REINTERPRET_OUTPUT_AS_3D
3601 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003602{
    // Each work-item computes a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by
    // NUM_ELEMS_PROCESSED_PER_THREAD_X columns of dst for batch get_global_id(2).
    // At most 4 rows are supported: only the accumulators acc0..acc3 exist below.
    // idx is the first column (in elements) of B / dst handled by this work-item.
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003603 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003604
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003605 // Compute starting address for matrix A and Matrix B
    // src_addr.s0 is a byte offset into src0 (A), src_addr.s1 a byte offset into src1 (B).
    // NOTE(review): offsets are kept in int, so buffers larger than INT_MAX bytes would
    // overflow here - presumably excluded by the host; confirm.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003606 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003607
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003608 // Update address for the matrix A
3609 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003610
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003611 // Update address for the matrix B
3612 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003613
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003614#if defined(REINTERPRET_INPUT_AS_3D)
3615 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3616 // in order to take into account the presence of possible cross plane paddings
3617 //
3618 // | |
3619 // | plane0 |
3620 // | |
3621 // |__________________|
3622 // |******************|
3623 // | cross_plane_pad |
3624 // |******************|
3625 // | |
3626 // | plane1 |
3627 // | |
3628 // |__________________|
3629
3630 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    // Lane N of zin corresponds to row N of this work-item's tile; the plane index is
    // clamped to DEPTH_GEMM3D - 1 so rows past the last plane stay on the last plane.
3631 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3632 zin = min(DEPTH_GEMM3D - 1, zin);
3633
3634 // Add offset due to the cross plane paddings
    // After this, zin.sN is the extra byte offset added to every load of row N of A.
3635 zin *= (src_cross_plane_pad * src0_stride_y);
3636
3637 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3638 // multiply src0_stride_z by DEPTH_GEMM3D
3639 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3640
3641#else // defined(REINTERPRET_INPUT_AS_3D)
3642
Gian Marcoae2af742018-02-15 12:35:44 +00003643 // Add offset for batched GEMM
3644 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003645
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003646#endif // defined(REINTERPRET_INPUT_AS_3D)
3647
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003648#if defined(MATRIX_B_DEPTH)
3649 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    // The batch index wraps modulo MATRIX_B_DEPTH, so several batches of A reuse the same slice of B.
3650 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3651#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003652 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003653#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003654
    // Byte offset one past the last of the COLS_A elements of the current row of A;
    // both K loops below run until src_addr.s0 reaches it.
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003655 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
3656
    // One accumulator vector per tile row. VECTOR_TYPE is defined outside this view -
    // presumably VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X); TODO confirm.
3657 VECTOR_TYPE acc0 = 0.0f;
3658#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3659 VECTOR_TYPE acc1 = 0.0f;
3660#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3661#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3662 VECTOR_TYPE acc2 = 0.0f;
3663#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3664#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3665 VECTOR_TYPE acc3 = 0.0f;
3666#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3667
    // Main K loop: each iteration consumes two consecutive elements of every A row and the two
    // matching rows of B (src_addr.s0 advances by 2 elements, src_addr.s1 by 2 rows of B).
Georgios Pinitas96880cf2017-10-20 18:52:20 +01003668 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003669 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003670#if defined(REINTERPRET_INPUT_AS_3D)
3671 // Load values from matrix A
    // LOAD_BLOCK (defined outside this view) declares a0..a(N-1), each 2 elements wide,
    // and applies the per-row cross-plane offsets zin.s0..zin.s3.
Usama Arif0681e3b2019-04-25 14:28:07 +01003672 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3673#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003674 // Load values from matrix A
3675 VEC_DATA_TYPE(DATA_TYPE, 2)
3676 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3677#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3678 VEC_DATA_TYPE(DATA_TYPE, 2)
3679 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3680#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3681#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3682 VEC_DATA_TYPE(DATA_TYPE, 2)
3683 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3686 VEC_DATA_TYPE(DATA_TYPE, 2)
3687 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3688#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003689#endif // defined(REINTERPRET_INPUT_AS_3D)
3690
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003691 // Load values from matrix B
3692 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
3693 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003694
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003695 // Accumulate
    // Row r of the tile: acc_r += a_r.s0 * (B row k) + a_r.s1 * (B row k + 1), with the scalar broadcast to a vector.
3696 acc0 += b0 * (VECTOR_TYPE)a0.s0;
3697 acc0 += b1 * (VECTOR_TYPE)a0.s1;
3698#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3699 acc1 += b0 * (VECTOR_TYPE)a1.s0;
3700 acc1 += b1 * (VECTOR_TYPE)a1.s1;
3701#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3702#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3703 acc2 += b0 * (VECTOR_TYPE)a2.s0;
3704 acc2 += b1 * (VECTOR_TYPE)a2.s1;
3705#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3706#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3707 acc3 += b0 * (VECTOR_TYPE)a3.s0;
3708 acc3 += b1 * (VECTOR_TYPE)a3.s1;
3709#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003710 }
3711
    // Tail K loop: fewer than two elements of the A row remain here, so this runs at most
    // once (when COLS_A is odd), consuming one element of A and one row of B.
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003712 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003713 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003714#if defined(REINTERPRET_INPUT_AS_3D)
3715 // Load values from matrix A
3716 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3717#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3718 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3719#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3720#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3721 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3722#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3723#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3724 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3725#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3726#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003727 // Load values from matrix A
3728 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3729#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3730 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3731#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3732#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3733 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3734#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3735#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3736 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3737#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003738#endif // defined(REINTERPRET_INPUT_AS_3D)
3739
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003740 // Load values from matrix B
3741 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003742
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003743 // Accumulate
3744 acc0 += b0 * (VECTOR_TYPE)a0;
3745#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3746 acc1 += b0 * (VECTOR_TYPE)a1;
3747#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3748#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3749 acc2 += b0 * (VECTOR_TYPE)a2;
3750#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3751#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3752 acc3 += b0 * (VECTOR_TYPE)a3;
3753#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003754 }
3755
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003756 // Compute destination address
    // CONVERT_TO_IMAGE_STRUCT / offset are defined outside this view; presumably dst_addr ends up
    // at this work-item's (x, y) tile origin via dst_step_x / dst_step_y - TODO confirm.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003757 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3758
Gian Marcoae2af742018-02-15 12:35:44 +00003759 // Compute dst address
3760 __global uchar *dst_addr = offset(&dst, 0, 0);
3761
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003762 // Multiply by the weight of matrix-matrix product and store the result
    // When ALPHA is not defined at compile time the accumulators are left unscaled.
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003763#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003764 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003765#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003766#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3767 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
3768#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3769#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3770 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
3771#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3772#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3773 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
3774#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3775
    // Optional bias-like vector: the same NUM_ELEMS_PROCESSED_PER_THREAD_X-wide slice of src2
    // (selected only by get_global_id(0)) is added to every row of the tile and every batch.
    // It is added AFTER the ALPHA scaling above, so c0 itself is not multiplied by ALPHA.
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003776#if defined(ADD_VEC_C)
3777 // *INDENT-OFF*
3778 // clang-format off
3779 __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3780 VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
3781 // clang-format on
3782 // *INDENT-ON*
3783
3784 acc0 += c0;
3785#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3786 acc1 += c0;
3787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3788#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3789 acc2 += c0;
3790#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3791#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3792 acc3 += c0;
3793#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3794#endif /* defined(ADD_VEC_C) */
3795
    // NOTE(review): no bounds checks are performed on any load or store in this kernel;
    // presumably the host sizes the NDRange / pads the tensors so every tile is in range - confirm.
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003796 int z = get_global_id(2);
3797
3798#if defined(REINTERPRET_OUTPUT_AS_3D)
3799 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003800 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003801 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003802 // | |
3803 // | plane0 |
3804 // | |
3805 // |__________________|
3806 // |******************|
3807 // | cross_plane_pad |
3808 // |******************|
3809 // | |
3810 // | plane1 |
3811 // | |
3812 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003813
3814 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    // Same per-row scheme as zin above: lane N is row N of the tile, clamped to the last plane.
3815 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3816 zout = min(DEPTH_GEMM3D - 1, zout);
3817
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003818 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003819 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003820
3821 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3822 // multiply dst_stride_z by DEPTH_GEMM3D
3823 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3824
3825 // Store output block
    // STORE_BLOCK (defined outside this view) writes acc0..acc(N-1), adding zout.sN to row N.
Usama Arif0681e3b2019-04-25 14:28:07 +01003826 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003827#else // defined(REINTERPRET_OUTPUT_AS_3D)
3828 // Add offset for batched GEMM
3829 dst_addr += z * dst_stride_z;
3830
3831 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003832 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003833 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003834#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003835 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003836 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003837#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3838#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003839 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003840 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003841#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3842#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003843 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003844 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003845#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003846#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003847}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003848#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003849
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003850/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003851 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003852 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3853 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003854 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
3855 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3856 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3857 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3858 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003859 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3860 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003861 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003862 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3863 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003864 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3865 * -# HEIGHT_GEMM3D: The height of each plane when the input and/or output has to be reinterpreted as a 3D tensor.
3866 * -# DEPTH_GEMM3D: The depth (number of planes) when the input and/or output has to be reinterpreted as a 3D tensor
3867 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
3868 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003869 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3870 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003871 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3872 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3873 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per work-item (in bytes)
3874 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3875 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per work-item (in bytes)
3876 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3877 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3878 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3879 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per work-item (in bytes)
3880 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3881 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per work-item (in bytes)
3882 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003883 * @param[in] src2_ptr (Optional) Pointer to the source vector (only if defined ADD_VEC_C). Supported data types: same as @p src0_ptr
3884 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3885 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
3886 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003887 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
3888 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3889 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per work-item (in bytes)
3890 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3891 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per work-item (in bytes)
3892 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003893 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3894 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3895 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003896 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3897 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003898 */
3899__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
3900 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003901#if defined(ADD_VEC_C)
3902 VECTOR_DECLARATION(src2),
3903#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003904 IMAGE_DECLARATION(dst),
3905 uint src0_stride_z,
3906 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003907 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003908#if defined(REINTERPRET_INPUT_AS_3D)
3909 ,
3910 uint src_cross_plane_pad
3911#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003912#if defined(REINTERPRET_OUTPUT_AS_3D)
3913 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003914 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003915#endif // REINTERPRET_OUTPUT_AS_3D
3916 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003917{
3918 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3919
3920 // Compute starting address for matrix A and matrix B
3921 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3922
3923 // Update address for matrix A
3924 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3925
3926 // Update address for matrix B
3927 src_addr.s1 += idx * sizeof(float);
3928
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003929#if defined(REINTERPRET_INPUT_AS_3D)
3930 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3931 // in order to take into account the presence of possible cross plane paddings
3932 //
3933 // | |
3934 // | plane0 |
3935 // | |
3936 // |__________________|
3937 // |******************|
3938 // | cross_plane_pad |
3939 // |******************|
3940 // | |
3941 // | plane1 |
3942 // | |
3943 // |__________________|
3944
3945 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3946 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3947 zin = min(DEPTH_GEMM3D - 1, zin);
3948
3949 // Add offset due to the cross plane paddings
3950 zin *= (src_cross_plane_pad * src0_stride_y);
3951
3952 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3953 // multiply src0_stride_z by DEPTH_GEMM3D
3954 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3955
3956#else // defined(REINTERPRET_INPUT_AS_3D)
3957
Gian Marcoae2af742018-02-15 12:35:44 +00003958 // Add offset for batched GEMM
3959 src_addr.s0 += get_global_id(2) * src0_stride_z;
3960
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003961#endif // defined(REINTERPRET_INPUT_AS_3D)
3962
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003963#if defined(MATRIX_B_DEPTH)
3964 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3965 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3966#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003967 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003968#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003969
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003970 // Initialize accumulators
3971 float acc00 = 0.0f;
3972 float acc01 = 0.0f;
3973 float acc02 = 0.0f;
3974 float acc03 = 0.0f;
3975
3976#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3977 float acc10 = 0.0f;
3978 float acc11 = 0.0f;
3979 float acc12 = 0.0f;
3980 float acc13 = 0.0f;
3981#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3982
3983#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3984 float acc20 = 0.0f;
3985 float acc21 = 0.0f;
3986 float acc22 = 0.0f;
3987 float acc23 = 0.0f;
3988#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3989
3990#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3991 float acc30 = 0.0f;
3992 float acc31 = 0.0f;
3993 float acc32 = 0.0f;
3994 float acc33 = 0.0f;
3995#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3996
3997 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003998 int i = 0;
3999 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004000 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004001#if defined(REINTERPRET_INPUT_AS_3D)
4002 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004003 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4004#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004005 // Load values from matrix A and matrix B
4006 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004007#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004008 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004009#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4010#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004011 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004012#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4013#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004014 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004015#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004016#endif // defined(REINTERPRET_INPUT_AS_3D)
4017
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004018 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4019 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004020
4021 // Multiply and accumulate
4022 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004023 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004024 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004025 acc03 = fma(a0.s0, b0.s3, acc03);
4026
4027#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004028
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004029 acc10 = fma(a1.s0, b0.s0, acc10);
4030 acc11 = fma(a1.s0, b0.s1, acc11);
4031 acc12 = fma(a1.s0, b0.s2, acc12);
4032 acc13 = fma(a1.s0, b0.s3, acc13);
4033
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004034#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4035#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004036
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004037 acc20 = fma(a2.s0, b0.s0, acc20);
4038 acc21 = fma(a2.s0, b0.s1, acc21);
4039 acc22 = fma(a2.s0, b0.s2, acc22);
4040 acc23 = fma(a2.s0, b0.s3, acc23);
4041
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004042#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4043#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004044
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004045 acc30 = fma(a3.s0, b0.s0, acc30);
4046 acc31 = fma(a3.s0, b0.s1, acc31);
4047 acc32 = fma(a3.s0, b0.s2, acc32);
4048 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004049#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004050
4051 // Load values from matrix A and matrix B
4052 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4053 src_addr.s1 += src1_stride_y;
4054
4055 // Multiply and accumulate
4056 acc00 = fma(a0.s1, b0.s0, acc00);
4057 acc01 = fma(a0.s1, b0.s1, acc01);
4058 acc02 = fma(a0.s1, b0.s2, acc02);
4059 acc03 = fma(a0.s1, b0.s3, acc03);
4060
4061#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4062
4063 acc10 = fma(a1.s1, b0.s0, acc10);
4064 acc11 = fma(a1.s1, b0.s1, acc11);
4065 acc12 = fma(a1.s1, b0.s2, acc12);
4066 acc13 = fma(a1.s1, b0.s3, acc13);
4067
4068#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4069#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4070
4071 acc20 = fma(a2.s1, b0.s0, acc20);
4072 acc21 = fma(a2.s1, b0.s1, acc21);
4073 acc22 = fma(a2.s1, b0.s2, acc22);
4074 acc23 = fma(a2.s1, b0.s3, acc23);
4075
4076#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4077#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4078
4079 acc30 = fma(a3.s1, b0.s0, acc30);
4080 acc31 = fma(a3.s1, b0.s1, acc31);
4081 acc32 = fma(a3.s1, b0.s2, acc32);
4082 acc33 = fma(a3.s1, b0.s3, acc33);
4083#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4084
4085 // Load values from matrix A and matrix B
4086 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4087 src_addr.s1 += src1_stride_y;
4088
4089 // Multiply and accumulate
4090 acc00 = fma(a0.s2, b0.s0, acc00);
4091 acc01 = fma(a0.s2, b0.s1, acc01);
4092 acc02 = fma(a0.s2, b0.s2, acc02);
4093 acc03 = fma(a0.s2, b0.s3, acc03);
4094
4095#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4096
4097 acc10 = fma(a1.s2, b0.s0, acc10);
4098 acc11 = fma(a1.s2, b0.s1, acc11);
4099 acc12 = fma(a1.s2, b0.s2, acc12);
4100 acc13 = fma(a1.s2, b0.s3, acc13);
4101
4102#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4103#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4104
4105 acc20 = fma(a2.s2, b0.s0, acc20);
4106 acc21 = fma(a2.s2, b0.s1, acc21);
4107 acc22 = fma(a2.s2, b0.s2, acc22);
4108 acc23 = fma(a2.s2, b0.s3, acc23);
4109
4110#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4111#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4112
4113 acc30 = fma(a3.s2, b0.s0, acc30);
4114 acc31 = fma(a3.s2, b0.s1, acc31);
4115 acc32 = fma(a3.s2, b0.s2, acc32);
4116 acc33 = fma(a3.s2, b0.s3, acc33);
4117#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4118
4119 // Load values from matrix A and matrix B
4120 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4121 src_addr.s1 += src1_stride_y;
4122
4123 // Multiply and accumulate
4124 acc00 = fma(a0.s3, b0.s0, acc00);
4125 acc01 = fma(a0.s3, b0.s1, acc01);
4126 acc02 = fma(a0.s3, b0.s2, acc02);
4127 acc03 = fma(a0.s3, b0.s3, acc03);
4128
4129#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4130
4131 acc10 = fma(a1.s3, b0.s0, acc10);
4132 acc11 = fma(a1.s3, b0.s1, acc11);
4133 acc12 = fma(a1.s3, b0.s2, acc12);
4134 acc13 = fma(a1.s3, b0.s3, acc13);
4135
4136#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4137#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4138
4139 acc20 = fma(a2.s3, b0.s0, acc20);
4140 acc21 = fma(a2.s3, b0.s1, acc21);
4141 acc22 = fma(a2.s3, b0.s2, acc22);
4142 acc23 = fma(a2.s3, b0.s3, acc23);
4143
4144#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4145#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4146
4147 acc30 = fma(a3.s3, b0.s0, acc30);
4148 acc31 = fma(a3.s3, b0.s1, acc31);
4149 acc32 = fma(a3.s3, b0.s2, acc32);
4150 acc33 = fma(a3.s3, b0.s3, acc33);
4151#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4152
4153 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004154 }
4155
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004156 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004157 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004158#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004159 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004160 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4161#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4162 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4163#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4164#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4165 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4166#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4167#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4168 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4169#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4170#else // defined(REINTERPRET_INPUT_AS_3D)
4171 // Load values from matrix A
4172 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004173#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4174 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4175#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4177 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4178#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4179#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4180 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004182#endif // defined(REINTERPRET_INPUT_AS_3D)
4183
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004184 // Load values from matrix B
4185 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004186 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004187
4188 // Multiply and accumulate
4189 acc00 = fma(a0, b0.s0, acc00);
4190 acc01 = fma(a0, b0.s1, acc01);
4191 acc02 = fma(a0, b0.s2, acc02);
4192 acc03 = fma(a0, b0.s3, acc03);
4193#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4194 acc10 = fma(a1, b0.s0, acc10);
4195 acc11 = fma(a1, b0.s1, acc11);
4196 acc12 = fma(a1, b0.s2, acc12);
4197 acc13 = fma(a1, b0.s3, acc13);
4198#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4199#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4200 acc20 = fma(a2, b0.s0, acc20);
4201 acc21 = fma(a2, b0.s1, acc21);
4202 acc22 = fma(a2, b0.s2, acc22);
4203 acc23 = fma(a2, b0.s3, acc23);
4204#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4205#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4206 acc30 = fma(a3, b0.s0, acc30);
4207 acc31 = fma(a3, b0.s1, acc31);
4208 acc32 = fma(a3, b0.s2, acc32);
4209 acc33 = fma(a3, b0.s3, acc33);
4210#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004211
4212 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004213 }
4214
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004215 int z = get_global_id(2);
4216
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004217 // Compute destination address
4218 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4219
4220 // Multiply by the weight of matrix-matrix product and store the result
4221#if defined(ALPHA)
4222 acc00 = acc00 * ALPHA;
4223 acc01 = acc01 * ALPHA;
4224 acc02 = acc02 * ALPHA;
4225 acc03 = acc03 * ALPHA;
4226#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004227#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004228 acc10 = acc10 * ALPHA;
4229 acc11 = acc11 * ALPHA;
4230 acc12 = acc12 * ALPHA;
4231 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004232#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4233#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004234 acc20 = acc20 * ALPHA;
4235 acc21 = acc21 * ALPHA;
4236 acc22 = acc22 * ALPHA;
4237 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004238#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4239#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004240 acc30 = acc30 * ALPHA;
4241 acc31 = acc31 * ALPHA;
4242 acc32 = acc32 * ALPHA;
4243 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004244#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4245
4246 // Compute dst address
4247 __global uchar *dst_addr = offset(&dst, 0, 0);
4248
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004249#if defined(ADD_VEC_C)
4250 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4251 float4 c0 = vload4(0, src2_addr);
4252
4253 acc00 += c0.s0;
4254 acc01 += c0.s1;
4255 acc02 += c0.s2;
4256 acc03 += c0.s3;
4257#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4258 acc10 += c0.s0;
4259 acc11 += c0.s1;
4260 acc12 += c0.s2;
4261 acc13 += c0.s3;
4262#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4263#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4264 acc20 += c0.s0;
4265 acc21 += c0.s1;
4266 acc22 += c0.s2;
4267 acc23 += c0.s3;
4268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4270 acc30 += c0.s0;
4271 acc31 += c0.s1;
4272 acc32 += c0.s2;
4273 acc33 += c0.s3;
4274#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4275#endif /* defined(ADD_VEC_C) */
4276
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004277#if defined(REINTERPRET_OUTPUT_AS_3D)
4278 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004279 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004280 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004281 // | |
4282 // | plane0 |
4283 // | |
4284 // |__________________|
4285 // |******************|
4286 // | cross_plane_pad |
4287 // |******************|
4288 // | |
4289 // | plane1 |
4290 // | |
4291 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004292
4293 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4294 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4295 zout = min(DEPTH_GEMM3D - 1, zout);
4296
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004297 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004298 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004299
4300 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4301 // multiply dst_stride_z by DEPTH_GEMM3D
4302 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4303
4304 // Store the output block
4305 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4307 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4308#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4309#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4310 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4311#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4312#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4313 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004314#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004315
4316#else // defined(REINTERPRET_OUTPUT_AS_3D)
4317 // Add offset for batched GEMM
4318 dst_addr += z * dst_stride_z;
4319
4320 // Store the output block
4321 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4322#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4323 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4324#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4325#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4326 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4327#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4328#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4329 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4330#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4331#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004332}
4333
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Every work-item loads/stores exactly 2 columns (vload2/vstore2 on matrix B, on the optional vector C and on the output),
 *       so the kernel must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2 for the output to be covered correctly.
 *       Accumulators exist for at most 4 rows, so NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4].
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 *       Alpha scales only the A*B product: the optional vector C is added after the scaling.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings of the input tensor, in number of rows (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings of the output tensor, in number of rows (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                      VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 2 columns of the output.
    // A is read 8 floats at a time (vload8) in the main loop, B 2 floats at a time (vload2).
    // The column width of 2 is hard-coded, hence NUM_ELEMS_PROCESSED_PER_THREAD_X must be 2.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (the padding is counted in rows, hence the row stride)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators (accRC: R = row of the tile, C = column of the tile)
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of A (and 8 rows of B) per iteration.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        // a0 is reused as the register for the next row of A
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }

    // Left-over loop: process the remaining (COLS_A % 8) columns of A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ADD_VEC_C)
    // The same 2 elements of vector C are added to every row of the tile (added after the ALPHA scaling)
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float2          c0        = vload2(0, src2_addr);

    acc00 += c0.s0;
    acc01 += c0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 += c0.s0;
    acc11 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 += c0.s0;
    acc21 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 += c0.s0;
    acc31 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (the padding is counted in rows, hence the row stride)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
4751
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004752#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point (float) variables.
 *       The optional scaling by ALPHA and the optional addition of src2 are also performed in 32-bit floating point, so the result is
 *       rounded to half only once, right before it is stored. This avoids double rounding and avoids intermediate overflow of the
 *       half range (65504) when the unscaled accumulator is large but the scaled result is not.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads and stores 8 columns (vload8/vstore8), so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 is expected; other values make
 *       the tiles of adjacent work-items overlap or leave gaps. -DNUM_ELEMS_PROCESSED_PER_THREAD_Y can be 1, 2, 3 or 4.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // fp32 accumulators: one float8 per processed row of the output tile
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int i = 0;
    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining COLS_A % 4 columns, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0); // b0 * (float8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1); // b0 * (float8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2); // b0 * (float8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3); // b0 * (float8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product while still in fp32.
    // Scaling after the narrowing to half would round twice and could overflow to inf
    // when the raw accumulator exceeds the half range even though alpha * acc does not.
#if defined(ALPHA)
    acc0 *= (float8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // The same 8-element slice of the src2 vector is added to every row of the tile (in fp32)
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8 c0 = convert_float8(vload8(0, src2_addr));
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Single narrowing conversion to the output data type (half)
    half8 hacc0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 hacc1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 hacc2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 hacc3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
5111
5112/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5113 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005114 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5115 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005116 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5117 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5118 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5119 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5120 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5121 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5122 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5123 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005124 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5125 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005126 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5127 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5128 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5129 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5130 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005131 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5132 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005133 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5134 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5135 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5136 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5137 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5138 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5139 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5140 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5141 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5142 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5143 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5144 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005145 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
5146 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5147 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5148 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005149 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5150 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5151 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5152 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5153 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5154 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005155 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5156 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5157 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005158 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5159 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005160 */
5161__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5162 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005163#if defined(ADD_VEC_C)
5164 VECTOR_DECLARATION(src2),
5165#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005166 IMAGE_DECLARATION(dst),
5167 uint src0_stride_z,
5168 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005169 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005170#if defined(REINTERPRET_INPUT_AS_3D)
5171 ,
5172 uint src_cross_plane_pad
5173#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005174#if defined(REINTERPRET_OUTPUT_AS_3D)
5175 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005176 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005177#endif // REINTERPRET_OUTPUT_AS_3D
5178 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005179{
5180 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5181
5182 // Compute starting address for matrix A and Matrix B
5183 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5184
5185 // Update address for the matrix A
5186 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5187
5188 // Update address for the matrix B
5189 src_addr.s1 += idx * sizeof(half);
5190
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005191#if defined(REINTERPRET_INPUT_AS_3D)
5192 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5193 // in order to take into account the presence of possible cross plane paddings
5194 //
5195 // | |
5196 // | plane0 |
5197 // | |
5198 // |__________________|
5199 // |******************|
5200 // | cross_plane_pad |
5201 // |******************|
5202 // | |
5203 // | plane1 |
5204 // | |
5205 // |__________________|
5206
5207 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5208 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5209 zin = min(DEPTH_GEMM3D - 1, zin);
5210
5211 // Add offset due to the cross plane paddings
5212 zin *= (src_cross_plane_pad * src0_stride_y);
5213
5214 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5215 // multiply src0_stride_z by DEPTH_GEMM3D
5216 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5217
5218#else // defined(REINTERPRET_INPUT_AS_3D)
5219
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005220 // Add offset for batched GEMM
5221 src_addr.s0 += get_global_id(2) * src0_stride_z;
5222
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005223#endif // defined(REINTERPRET_INPUT_AS_3D)
5224
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005225#if defined(MATRIX_B_DEPTH)
5226 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5227 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5228#else // defined(MATRIX_B_DEPTH)
5229 src_addr.s1 += get_global_id(2) * src1_stride_z;
5230#endif // defined(MATRIX_B_DEPTH)
5231
5232 half8 acc0 = 0.0h;
5233#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5234 half8 acc1 = 0.0h;
5235#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5237 half8 acc2 = 0.0h;
5238#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5239#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5240 half8 acc3 = 0.0h;
5241#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5242
5243 int i = 0;
5244 for(; i <= ((int)COLS_A - 4); i += 4)
5245 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005246#if defined(REINTERPRET_INPUT_AS_3D)
5247 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005248 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5249#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005250 // Load values from matrix A
5251 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5252#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5253 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5254#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5255#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5256 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5257#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5258#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5259 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5260#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005261#endif // defined(REINTERPRET_INPUT_AS_3D)
5262
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005263 // Load values from matrix B
5264 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5265 src_addr.s1 += src1_stride_y;
5266
5267 // Accumulate
5268 acc0 = fma(b0, (half8)a0.s0, acc0);
5269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5270 acc1 = fma(b0, (half8)a1.s0, acc1);
5271#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5272#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5273 acc2 = fma(b0, (half8)a2.s0, acc2);
5274#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5275#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5276 acc3 = fma(b0, (half8)a3.s0, acc3);
5277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5278
5279 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5280 src_addr.s1 += src1_stride_y;
5281 acc0 = fma(b0, (half8)a0.s1, acc0);
5282#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5283 acc1 = fma(b0, (half8)a1.s1, acc1);
5284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5286 acc2 = fma(b0, (half8)a2.s1, acc2);
5287#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5288#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5289 acc3 = fma(b0, (half8)a3.s1, acc3);
5290#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5291
5292 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5293 src_addr.s1 += src1_stride_y;
5294 acc0 = fma(b0, (half8)a0.s2, acc0);
5295#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5296 acc1 = fma(b0, (half8)a1.s2, acc1);
5297#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5298#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5299 acc2 = fma(b0, (half8)a2.s2, acc2);
5300#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5301#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5302 acc3 = fma(b0, (half8)a3.s2, acc3);
5303#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5304
5305 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5306 src_addr.s1 += src1_stride_y;
5307 acc0 = fma(b0, (half8)a0.s3, acc0);
5308#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5309 acc1 = fma(b0, (half8)a1.s3, acc1);
5310#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5311#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5312 acc2 = fma(b0, (half8)a2.s3, acc2);
5313#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5315 acc3 = fma(b0, (half8)a3.s3, acc3);
5316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5317
5318 src_addr.s0 += 4 * sizeof(half);
5319 }
5320
5321 for(; i < (int)COLS_A; ++i)
5322 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005323#if defined(REINTERPRET_INPUT_AS_3D)
5324 // Load values from matrix A
5325 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5326#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5327 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5328#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5329#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5330 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5331#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5332#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5333 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5334#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5335#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005336 // Load values from matrix A
5337 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5338#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5339 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5340#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5341#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5342 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5343#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5344#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5345 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5346#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005347#endif // defined(REINTERPRET_INPUT_AS_3D)
5348
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005349 // Load values from matrix B
5350 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5351
5352 src_addr += (int2)(sizeof(half), src1_stride_y);
5353
5354 // Accumulate
5355 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
5356#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5357 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
5358#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5359#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5360 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
5361#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5362#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5363 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
5364#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5365 }
5366
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005367 // Multiply by the weight of matrix-matrix product and store the result
5368#if defined(ALPHA)
5369 acc0 = acc0 * (half8)ALPHA;
5370#endif // defined(ALPHA)
5371#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5372 acc1 = acc1 * (half8)ALPHA;
5373#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5374#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5375 acc2 = acc2 * (half8)ALPHA;
5376#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5378 acc3 = acc3 * (half8)ALPHA;
5379#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5380
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005381#if defined(ADD_VEC_C)
5382 // *INDENT-OFF*
5383 // clang-format off
5384 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
5385 half8 c0 = vload8(0, src2_addr);
5386 // clang-format on
5387 // *INDENT-ON*
5388
5389 acc0 += c0;
5390#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5391 acc1 += c0;
5392#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5393#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5394 acc2 += c0;
5395#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5396#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5397 acc3 += c0;
5398#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5399#endif /* defined(ADD_VEC_C) */
5400
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005401 int z = get_global_id(2);
5402
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005403 // Compute destination address
5404 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5405
5406 // Compute dst address
5407 __global uchar *dst_addr = offset(&dst, 0, 0);
5408
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005409#if defined(REINTERPRET_OUTPUT_AS_3D)
5410 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005411 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005412 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005413 // | |
5414 // | plane0 |
5415 // | |
5416 // |__________________|
5417 // |******************|
5418 // | cross_plane_pad |
5419 // |******************|
5420 // | |
5421 // | plane1 |
5422 // | |
5423 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005424
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005425 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5426 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5427 zout = min(DEPTH_GEMM3D - 1, zout);
5428
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005429 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005430 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005431
5432 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5433 // multiply dst_stride_z by DEPTH_GEMM3D
5434 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
5435
5436 // Store the output block
Usama Arif0681e3b2019-04-25 14:28:07 +01005437 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005438#else // defined(REINTERPRET_OUTPUT_AS_3D)
5439 // Add offset for batched GEMM
5440 dst_addr += z * dst_stride_z;
5441
5442 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005443 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
5444#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005445 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
5446#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5447#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005448 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
5449#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5450#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005451 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
5452#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005453#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005454}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005455#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005456
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005457#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005458
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005459#if defined(BETA)
/** Finishes a GEMM in place by adding a beta-weighted matrix C to the existing A x B result (F32).
 *
 * For every element: dst = dst + BETA * src. dst is read first and then overwritten with the sum.
 * Each work-item handles 4 consecutive F32 elements along X (vload4/vstore4).
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note NOTE(review): there is no bounds check; this presumably relies on the tensors being padded
 *       to a multiple of 4 elements along X. Confirm against the host-side configuration.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (holds A x B on input, the final result on output). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per-work-item views: matrix C comes from src, the A x B product lives in dst
    Tensor3D c_tensor  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D ab_tensor = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *ab_addr = (__global float *)ab_tensor.ptr;

    // (alpha * A x B) + beta * C, written back over the A x B values
    const float4 result = vload4(0, ab_addr) + (float4)BETA * vload4(0, (__global float *)c_tensor.ptr);

    vstore4(result, 0, ab_addr);
}
5500
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005501#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** Finishes a GEMM in place by adding a beta-weighted matrix C to the existing A x B result (F16).
 *
 * For every element: dst = dst + BETA * src. dst is read first and then overwritten with the sum.
 * Each work-item handles 8 consecutive F16 elements along X (vload8/vstore8).
 *
 * @note The value of beta must be passed at compile time using -DBETA
 * @note NOTE(review): there is no bounds check; this presumably relies on the tensors being padded
 *       to a multiple of 8 elements along X. Confirm against the host-side configuration.
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (holds A x B on input, the final result on output). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per-work-item views: matrix C comes from src, the A x B product lives in dst
    Tensor3D c_tensor  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D ab_tensor = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *ab_addr = (__global half *)ab_tensor.ptr;

    // (alpha * A x B) + beta * C, written back over the A x B values
    const half8 result = vload8(0, ab_addr) + (half8)BETA * vload8(0, (__global half *)c_tensor.ptr);

    vstore8(result, 0, ab_addr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005542#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005543#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005544
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005545#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix product used by the locally connected layer (F32).
 *
 * Work-item (x, y) multiplies row y of A (src0) by Z-slice y of B (src1). It produces the 4 output
 * values at columns [4x, 4x + 3] and stores them at its position in dst.
 * The reduction consumes two elements of A per step, then handles a possible odd trailing element.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix (A). Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix (B). Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int first_col = get_global_id(0) * 4; // first of the 4 output columns owned by this work-item
    const int row       = get_global_id(1);     // row of A, and Z-slice of B

    // Byte offsets: .s0 walks along the row of A, .s1 walks down the rows of B
    int2 offs = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * row, src1_offset_first_element_in_bytes + src1_stride_z * row));
    offs.s1 += first_col * sizeof(float);

    const int a_end        = offs.s0 + (WIDTH_VECTOR_A * sizeof(float));
    const int a_pair_limit = a_end - 2 * (int)sizeof(float);

    float4 acc = (float4)0.0f;

    // Main loop: two elements of A (and two rows of B) per iteration
    while(offs.s0 <= a_pair_limit)
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + offs.s0));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + offs.s1));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + offs.s1 + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        offs += (int2)(2 * sizeof(float), 2 * src1_stride_y);
    }

    // Leftover: at most one element of A remains when WIDTH_VECTOR_A is odd
    while(offs.s0 < a_end)
    {
        const float  a_val = *((__global float *)(src0_ptr + offs.s0));
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + offs.s1));

        acc += b_row * (float4)a_val;

        offs += (int2)(sizeof(float), src1_stride_y);
    }

    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)(offset(&out_img, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005611#endif // defined(WIDTH_VECTOR_A)
5612
/** Adds the biases vector to every row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulator and from the biases vector at its
 * position, adds them (biases + accum) and stores the sum back into the accumulator.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum_img = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector bias_vec  = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_addr = (__global DATA_TYPE *)accum_img.ptr;

    // VECTOR_SIZE elements of each operand at this work-item's position
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) bias_vals = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)bias_vec.ptr);
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) acc_vals  = VLOAD(VECTOR_SIZE)(0, accum_addr);

    // Keep the sum in a DATA_TYPE vector so the stored type matches the accumulator exactly
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) sum = bias_vals + acc_vals;

    // Write the result back into the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (sum, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)