blob: 7a861dd20778bbd2b5f98098d5e4d999a5dbd96a [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000026#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
27
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Work-item (x, y, z) reads the M0xK0 block whose top-left element is at column (x * K0) and row (y * M0) of batch z.
 * V0 blocks with consecutive y are stored on the same output row (y / V0):
 *  - without -DINTERLEAVE, each block is written as M0 contiguous rows of K0 elements (M0 * K0 elements per block);
 *  - with -DINTERLEAVE, the rows of the V0 blocks are interleaved, so consecutive rows of one block are K0 * V0 elements apart.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A (not reshaped)
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Input loads are not bounds-checked. NOTE(review): presumably the caller pads the source so full M0xK0 blocks are always readable — confirm.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom padding between consecutive planes, in rows (it is multiplied by src_stride_y).
 *                                               Only present if REINTERPRET_INPUT_AS_3D is defined
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Block size (in elements)
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between two consecutive blocks on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block in the output.
    // Fully parenthesized so the macro expands safely inside any expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Per-row byte offsets that skip the cross-plane padding (stay 0 when the input is not reinterpreted as 3D)
    uint zin0 = 0;
    uint zin1 = 0;
    uint zin2 = 0;
    uint zin3 = 0;
    uint zin4 = 0;
    uint zin5 = 0;
    uint zin6 = 0;
    uint zin7 = 0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    // Row 7 only exists when M0 == 8
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    VSTORE(K0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if M0 > 1
    VSTORE(K0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 1
#if M0 > 2
    VSTORE(K0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 2
#if M0 > 3
    VSTORE(K0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 3
#if M0 > 4
    VSTORE(K0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 4
#if M0 > 5
    VSTORE(K0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 5
#if M0 > 6
    VSTORE(K0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 6
#if M0 > 7
    VSTORE(K0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 7

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000253
// TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
//
// Gathers column i of the M0xK0 block held in the caller's vectors a0 .. a(M0-1) into M0 consecutive
// elements and stores them at (output_ptr + i * output_step_x * sizeof(DATA_TYPE)).
//  - output_ptr:    __global uchar pointer to the start of the destination block
//  - output_step_x: distance, in DATA_TYPE elements, between the stores of two consecutive columns
//  - i:             a single hexadecimal digit token (0-9, A-F). It selects the vector component a<n>.s<i>
//                   and, pasted as 0x<i>, also yields the numeric column index.
// The vectors a0 .. a(M0-1) are NOT parameters: they must be declared in the scope where the macro is used.
// Widths 5, 6 and 7 are not OpenCL vector sizes, so those cases issue a 4-wide store followed by a
// 1-, 2- or 3-element store.
// Each variant expands to a do { ... } while(0) statement so it can be used like a function call followed
// by ';' without relying on the GNU statement-expression extension.
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                       \
    do                                                                                                 \
    {                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                        \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));  \
    } while(0)
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                       \
    do                                                                                                 \
    {                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                               \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));  \
    } while(0)
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                       \
    do                                                                                                 \
    {                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                      \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));  \
    } while(0)
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                       \
    do                                                                                                 \
    {                                                                                                  \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                    \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);            \
        DATA_TYPE res1 = a4.s##i;                                                                      \
        VSTORE(4)                                                                                      \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    } while(0)
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                           \
    do                                                                                                     \
    {                                                                                                      \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                            \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
        VSTORE(2)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    } while(0)
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                           \
    do                                                                                                     \
    {                                                                                                      \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                   \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
        VSTORE(3)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    } while(0)
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                   \
    do                                                                                                             \
    {                                                                                                              \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                               \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                 \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));              \
    } while(0)
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
323
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * Work-item (x, y, z) reads the M0xK0 block whose top-left element is at column (x * K0) and row (y * M0) of batch z.
 * Each of the K0 columns of the block is written as M0 consecutive elements. V0 blocks with consecutive y are stored
 * on the same output row (y / V0):
 *  - without -DINTERLEAVE, the transposed block occupies M0 * K0 contiguous elements;
 *  - with -DINTERLEAVE, the columns of the V0 blocks are interleaved, so consecutive columns of one block are M0 * V0 elements apart.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A (not reshaped)
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Input loads are not bounds-checked. NOTE(review): presumably the caller pads the source so full M0xK0 blocks are always readable — confirm.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom padding between consecutive planes, in rows (it is multiplied by src_stride_y).
 *                                               Only present if REINTERPRET_INPUT_AS_3D is defined
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Block size (in elements)
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between two consecutive blocks on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive transposed columns of the same block.
    // Fully parenthesized so the macro expands safely inside any expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Per-row byte offsets that skip the cross-plane padding (stay 0 when the input is not reinterpreted as 3D)
    uint zin0 = 0;
    uint zin1 = 0;
    uint zin2 = 0;
    uint zin3 = 0;
    uint zin4 = 0;
    uint zin5 = 0;
    uint zin6 = 0;
    uint zin7 = 0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    // Row 7 only exists when M0 == 8
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------
    // TRANSPOSE_COLUMN_AND_STORE reads a0 .. a(M0-1) from this scope; the last argument is a hex digit column index

    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 2
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000541#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
542
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000543#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Work-item (x, y, z) reads the K0xN0 block whose top-left element is at column (x * N0) and row (y * K0) of batch z.
 * Rows 1 .. K0-1 of the block that fall at or beyond SRC_HEIGHT are not read and are written out as zeros.
 * Row 0 is always read. NOTE(review): this assumes the global size along Y never exceeds ceil(SRC_HEIGHT / K0) — confirm.
 * H0 blocks with consecutive x (i.e. horizontally adjacent blocks) are stored on the same output row (x / H0):
 *  - without -DINTERLEAVE, each block is written as K0 contiguous rows of N0 elements (K0 * N0 elements per block);
 *  - with -DINTERLEAVE, the rows of the H0 blocks are interleaved, so consecutive rows of one block are N0 * H0 elements apart.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 horizontal blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,4,8,16
 *                                      K0: 1,2,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Block size (in elements)
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between two consecutive blocks on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block in the output.
    // Fully parenthesized so the macro expands safely inside any expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE))
                                 + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Rows that are not loaded (beyond K0 or beyond SRC_HEIGHT) keep the value 0
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a0 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a1 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a2 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a3 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a4 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a5 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a6 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a7 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a8 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a9 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aA = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aB = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aC = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aD = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aE = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aF = 0;

    // Load values from the RHS matrix
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    VSTORE(N0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if K0 > 1
    VSTORE(N0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 1
#if K0 > 2
    VSTORE(N0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 2
#if K0 > 4
    VSTORE(N0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 4
#if K0 > 8
    VSTORE(N0)
    (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
763
#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item (x, y, z) reads the K0xN0 block that starts at column x * N0 and row y * K0 of plane z, transposes it
 * into N0 vectors of K0 elements and writes them to the output. Row 0 of the block is always read; any other row whose
 * index is >= SRC_HEIGHT is not read and is filled with zeros instead. Columns are NOT bounds-checked: N0 elements are
 * always read from column x * N0 (presumably the host guarantees enough padding - TODO confirm).
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=4, -DN0=2).
 * @note The number of K0xN0 horizontal blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                    N0: 2,4,8,16
 *                                    K0: 4,8,16
 *                                    H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Reject unsupported N0 values up front with a clear message, instead of an obscure swizzle error further down
#if !((N0 == 2) || (N0 == 4) || (N0 == 8) || (N0 == 16))
#error "Not supported N0 value"
#endif // N0 check

    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between the blocks of consecutive work-items along X on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive transposed rows of the same block
    // (fully parenthesized so the macro is safe in any expression context)
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Zero-initialized so that rows skipped by the SRC_HEIGHT checks below contribute zeros to the output
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a0 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a1 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a2 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a3 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a4 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a5 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a6 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a7 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a8 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a9 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aA = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aB = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aC = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aD = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aE = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aF = 0;

    // Load values from the RHS matrix (row 0 unconditionally, the others only if inside SRC_HEIGHT)
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------

    // resJ holds column J of the input block, i.e. row J of the transposed block
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res0 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res1 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res2 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res3 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res4 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res5 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res6 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res7 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res8 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    res9 = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resA = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resB = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resC = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resD = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resE = 0;
    VEC_DATA_TYPE(DATA_TYPE, K0)
    resF = 0;

#if K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------

    VSTORE(K0)
    (res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
    VSTORE(K0)
    (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 4
    VSTORE(K0)
    (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
    VSTORE(K0)
    (res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
1121#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
1122
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001123#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
1124
// ARM_DOT(x, y, val): accumulates into val the dot product of the first four components (.s0 to .s3) of x and y,
// using one fma() per component, in the order s0, s1, s2, s3.
// Every argument is parenthesized so that expressions such as *ptr can be passed safely.
// val is read and written several times, so it must be an lvalue without side effects.
#define ARM_DOT(x, y, val)                  \
    ({                                      \
        (val) = fma((x).s0, (y).s0, (val)); \
        (val) = fma((x).s1, (y).s1, (val)); \
        (val) = fma((x).s2, (y).s2, (val)); \
        (val) = fma((x).s3, (y).s3, (val)); \
    })
1132
// ARM_DOT_K0(lhs, rhs, acc): accumulates into the scalar acc the dot product of two K0-element vectors,
// issuing one ARM_DOT call per consecutive 4-element slice (.s0123, .s4567, .s89AB, .sCDEF).
#if K0 == 4
#define ARM_DOT_K0(lhs, rhs, acc)   \
    ({                              \
        ARM_DOT((lhs), (rhs), acc); \
    })
#elif K0 == 8
#define ARM_DOT_K0(lhs, rhs, acc)               \
    ({                                          \
        ARM_DOT((lhs).s0123, (rhs).s0123, acc); \
        ARM_DOT((lhs).s4567, (rhs).s4567, acc); \
    })
#elif K0 == 16
#define ARM_DOT_K0(lhs, rhs, acc)               \
    ({                                          \
        ARM_DOT((lhs).s0123, (rhs).s0123, acc); \
        ARM_DOT((lhs).s4567, (rhs).s4567, acc); \
        ARM_DOT((lhs).s89AB, (rhs).s89AB, acc); \
        ARM_DOT((lhs).sCDEF, (rhs).sCDEF, acc); \
    })
#else // Any other K0 has no slicing scheme defined
#error "K0 value not supported"
#endif // K0 == 4 / 8 / 16
1155
// ARM_DOT_K0XN0(lhs, rhs, acc): for every output column j in [0, N0), accumulates into acc.sj the K0-element
// dot product of lhs with the vector named rhs<j>. The names rhs0, rhs1, ..., rhsF are formed by token pasting,
// so rhs must be a plain identifier prefix (e.g. b -> b0, b1, ...).
#if N0 == 2
#define ARM_DOT_K0XN0(lhs, rhs, acc)             \
    ({                                           \
        ARM_DOT_K0((lhs), (rhs##0), ((acc).s0)); \
        ARM_DOT_K0((lhs), (rhs##1), ((acc).s1)); \
    })
#elif N0 == 4
#define ARM_DOT_K0XN0(lhs, rhs, acc)             \
    ({                                           \
        ARM_DOT_K0((lhs), (rhs##0), ((acc).s0)); \
        ARM_DOT_K0((lhs), (rhs##1), ((acc).s1)); \
        ARM_DOT_K0((lhs), (rhs##2), ((acc).s2)); \
        ARM_DOT_K0((lhs), (rhs##3), ((acc).s3)); \
    })
#elif N0 == 8
#define ARM_DOT_K0XN0(lhs, rhs, acc)             \
    ({                                           \
        ARM_DOT_K0((lhs), (rhs##0), ((acc).s0)); \
        ARM_DOT_K0((lhs), (rhs##1), ((acc).s1)); \
        ARM_DOT_K0((lhs), (rhs##2), ((acc).s2)); \
        ARM_DOT_K0((lhs), (rhs##3), ((acc).s3)); \
        ARM_DOT_K0((lhs), (rhs##4), ((acc).s4)); \
        ARM_DOT_K0((lhs), (rhs##5), ((acc).s5)); \
        ARM_DOT_K0((lhs), (rhs##6), ((acc).s6)); \
        ARM_DOT_K0((lhs), (rhs##7), ((acc).s7)); \
    })
#elif N0 == 16
#define ARM_DOT_K0XN0(lhs, rhs, acc)             \
    ({                                           \
        ARM_DOT_K0((lhs), (rhs##0), ((acc).s0)); \
        ARM_DOT_K0((lhs), (rhs##1), ((acc).s1)); \
        ARM_DOT_K0((lhs), (rhs##2), ((acc).s2)); \
        ARM_DOT_K0((lhs), (rhs##3), ((acc).s3)); \
        ARM_DOT_K0((lhs), (rhs##4), ((acc).s4)); \
        ARM_DOT_K0((lhs), (rhs##5), ((acc).s5)); \
        ARM_DOT_K0((lhs), (rhs##6), ((acc).s6)); \
        ARM_DOT_K0((lhs), (rhs##7), ((acc).s7)); \
        ARM_DOT_K0((lhs), (rhs##8), ((acc).s8)); \
        ARM_DOT_K0((lhs), (rhs##9), ((acc).s9)); \
        ARM_DOT_K0((lhs), (rhs##A), ((acc).sA)); \
        ARM_DOT_K0((lhs), (rhs##B), ((acc).sB)); \
        ARM_DOT_K0((lhs), (rhs##C), ((acc).sC)); \
        ARM_DOT_K0((lhs), (rhs##D), ((acc).sD)); \
        ARM_DOT_K0((lhs), (rhs##E), ((acc).sE)); \
        ARM_DOT_K0((lhs), (rhs##F), ((acc).sF)); \
    })
#else // Any other N0 has no column expansion defined
#error "N0 value not supported"
#endif // N0 == 2 / 4 / 8 / 16
1205
1206/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1207 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1208 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1209 *
1210 * @note The number of columns in the RHS matrix NOT reshaped needs to be passed at compile time using -DK (i.e. -Dk=128).
1211 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1212 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1213 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1214 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1215 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1216 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1217 * - M0 = 2, 3, 4, 5, 6, 7, 8
1218 * - N0 = 2, 4, 8, 16
1219 * - K0 = 4, 8, 16
1220 *
1221 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1222 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1223 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1224 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1225 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1226 *
1227 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1228 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1229 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1230 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1231 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1232 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001233 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001234 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1235 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1236 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1237 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1238 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001239 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001240 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1241 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1242 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1243 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1244 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1245 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1246 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1247 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1248 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1249 */
1250__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1251 IMAGE_DECLARATION(rhs),
1252 IMAGE_DECLARATION(dst),
1253 uint lhs_stride_z,
1254 uint rhs_stride_z,
1255 uint dst_stride_z
1256#if defined(REINTERPRET_OUTPUT_AS_3D)
1257 ,
1258 uint dst_cross_plane_pad
1259#endif // REINTERPRET_OUTPUT_AS_3D
1260 )
1261{
1262 // Block size
1263#define LHS_BLOCK_SIZE ((K0) * (M0))
1264
1265#if defined(LHS_INTERLEAVE)
1266#define LHS_OFFSET_X (K0)
1267#define LHS_STEP_X ((K0) * (V0))
1268#define LHS_STEP_LOOP (1)
1269#else // defined(INTERLEAVE)
1270#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1271#define LHS_STEP_X (K0)
1272#define LHS_STEP_LOOP (V0)
1273#endif // defined(INTERLEAVE)
1274
1275 // Block size
1276#define RHS_BLOCK_SIZE ((K0) * (N0))
1277
1278 // RHS offset and step X
1279#if defined(RHS_INTERLEAVE)
1280#define RHS_OFFSET_X (K0)
1281#define RHS_STEP_X ((K0) * (H0))
1282#define RHS_STEP_LOOP (1)
1283#else // defined(RHS_INTERLEAVE)
1284#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1285#define RHS_STEP_X (K0)
1286#define RHS_STEP_LOOP (H0)
1287#endif // defined(RHS_INTERLEAVE)
1288
1289 // Compute LHS matrix address
1290 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1291 (get_global_id(2) * lhs_stride_z);
1292
1293 // Compute RHS matrix address
1294 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1295
1296#if defined(MATRIX_B_DEPTH)
1297 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1298 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1299#else // defined(MATRIX_B_DEPTH)
1300 rhs_addr += get_global_id(2) * rhs_stride_z;
1301#endif // defined(MATRIX_B_DEPTH)
1302
1303 // Initialize the accumulators
1304 VEC_DATA_TYPE(DATA_TYPE, N0)
1305 c0 = 0;
1306#if M0 > 1
1307 VEC_DATA_TYPE(DATA_TYPE, N0)
1308 c1 = 0;
1309#endif // M0 > 1
1310#if M0 > 2
1311 VEC_DATA_TYPE(DATA_TYPE, N0)
1312 c2 = 0;
1313#endif // M0 > 2
1314#if M0 > 3
1315 VEC_DATA_TYPE(DATA_TYPE, N0)
1316 c3 = 0;
1317#endif // M0 > 3
1318#if M0 > 4
1319 VEC_DATA_TYPE(DATA_TYPE, N0)
1320 c4 = 0;
1321#endif // M0 > 4
1322#if M0 > 5
1323 VEC_DATA_TYPE(DATA_TYPE, N0)
1324 c5 = 0;
1325#endif // M0 > 5
1326#if M0 > 6
1327 VEC_DATA_TYPE(DATA_TYPE, N0)
1328 c6 = 0;
1329#endif // M0 > 6
1330#if M0 > 7
1331 VEC_DATA_TYPE(DATA_TYPE, N0)
1332 c7 = 0;
1333#endif // M0 > 7
1334
1335 for(int i = 0; i < K; i += K0)
1336 {
1337 // Supported cases (M0, K0):
1338 // 2,4 - 2,8 - 2,16
1339 // 3,4 - 3,8 - 3,16
1340 // 4,4 - 4,8 - 4,16
1341 // 5,4 - 5,8 - 5,16
1342 // 6,4 - 6,8 - 6,16
1343 // Load values from LHS matrix
1344 VEC_DATA_TYPE(DATA_TYPE, K0)
1345 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE)));
1346#if M0 > 1
1347 VEC_DATA_TYPE(DATA_TYPE, K0)
1348 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE)));
1349#endif // M0 > 1
1350#if M0 > 2
1351 VEC_DATA_TYPE(DATA_TYPE, K0)
1352 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE)));
1353#endif // M0 > 2
1354#if M0 > 3
1355 VEC_DATA_TYPE(DATA_TYPE, K0)
1356 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE)));
1357#endif // M0 > 3
1358#if M0 > 4
1359 VEC_DATA_TYPE(DATA_TYPE, K0)
1360 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE)));
1361#endif // M0 > 4
1362#if M0 > 5
1363 VEC_DATA_TYPE(DATA_TYPE, K0)
1364 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE)));
1365#endif // M0 > 5
1366#if M0 > 6
1367 VEC_DATA_TYPE(DATA_TYPE, K0)
1368 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE)));
1369#endif // M0 > 6
1370#if M0 > 7
1371 VEC_DATA_TYPE(DATA_TYPE, K0)
1372 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE)));
1373#endif // M0 > 7
1374
1375 // Load values from RHS matrix
1376 VEC_DATA_TYPE(DATA_TYPE, K0)
1377 b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1378 VEC_DATA_TYPE(DATA_TYPE, K0)
1379 b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
1380#if N0 > 2
1381 VEC_DATA_TYPE(DATA_TYPE, K0)
1382 b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
1383 VEC_DATA_TYPE(DATA_TYPE, K0)
1384 b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
1385#endif // N0 > 2
1386#if N0 > 4
1387 VEC_DATA_TYPE(DATA_TYPE, K0)
1388 b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
1389 VEC_DATA_TYPE(DATA_TYPE, K0)
1390 b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
1391 VEC_DATA_TYPE(DATA_TYPE, K0)
1392 b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
1393 VEC_DATA_TYPE(DATA_TYPE, K0)
1394 b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
1395#endif // N0 > 4
1396#if N0 > 8
1397 VEC_DATA_TYPE(DATA_TYPE, K0)
1398 b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
1399 VEC_DATA_TYPE(DATA_TYPE, K0)
1400 b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
1401 VEC_DATA_TYPE(DATA_TYPE, K0)
1402 bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
1403 VEC_DATA_TYPE(DATA_TYPE, K0)
1404 bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
1405 VEC_DATA_TYPE(DATA_TYPE, K0)
1406 bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
1407 VEC_DATA_TYPE(DATA_TYPE, K0)
1408 bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
1409 VEC_DATA_TYPE(DATA_TYPE, K0)
1410 bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
1411 VEC_DATA_TYPE(DATA_TYPE, K0)
1412 bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
1413#endif // N0 > 8
1414
1415 // Accumulate
1416 ARM_DOT_K0XN0(a0, b, c0);
1417#if M0 > 1
1418 ARM_DOT_K0XN0(a1, b, c1);
1419#endif // M0 > 1
1420#if M0 > 2
1421 ARM_DOT_K0XN0(a2, b, c2);
1422#endif // M0 > 2
1423#if M0 > 3
1424 ARM_DOT_K0XN0(a3, b, c3);
1425#endif // M0 > 3
1426#if M0 > 4
1427 ARM_DOT_K0XN0(a4, b, c4);
1428#endif // M0 > 4
1429#if M0 > 5
1430 ARM_DOT_K0XN0(a5, b, c5);
1431#endif // M0 > 5
1432#if M0 > 6
1433 ARM_DOT_K0XN0(a6, b, c6);
1434#endif // M0 > 6
1435#if M0 > 7
1436 ARM_DOT_K0XN0(a7, b, c7);
1437#endif // M0 > 7
1438
1439 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1440 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1441 }
1442
1443 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1444
1445 uint zout0 = 0;
1446 uint zout1 = 0;
1447 uint zout2 = 0;
1448 uint zout3 = 0;
1449 uint zout4 = 0;
1450 uint zout5 = 0;
1451 uint zout6 = 0;
1452 uint zout7 = 0;
1453
1454#if defined(REINTERPRET_OUTPUT_AS_3D)
1455 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1456 // in order to take into account the presence of possible cross plane paddings
1457 //
1458 // | |
1459 // | plane0 |
1460 // | |
1461 // |__________________|
1462 // |******************|
1463 // | cross_plane_pad |
1464 // |******************|
1465 // | |
1466 // | plane1 |
1467 // | |
1468 // |__________________|
1469
1470 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1471 zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1472 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001473 zout0 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001474#if M0 > 1
1475 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1476 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001477 zout1 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001478#endif // M0 > 1
1479#if M0 > 2
1480 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1481 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001482 zout2 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001483#endif // M0 > 2
1484#if M0 > 3
1485 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1486 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001487 zout3 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001488#endif // M0 > 3
1489#if M0 > 4
1490 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1491 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001492 zout4 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001493#endif // M0 > 4
1494#if M0 > 5
1495 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1496 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001497 zout5 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001498#endif // M0 > 5
1499#if M0 > 6
1500 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1501 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001502 zout6 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001503#endif // M0 > 6
1504#if M0 > 6
1505 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1506 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001507 zout7 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001508#endif // M0 > 7
1509
1510 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1511 // multiply dst_stride_z by DEPTH_GEMM3D
1512 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1513
1514#else // defined(REINTERPRET_OUTPUT_AS_3D)
1515
1516 // Add offset for batched GEMM
1517 dst_addr += get_global_id(2) * dst_stride_z;
1518
1519#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1520
1521 // Multiply by the weight of matrix-matrix product and store the result
1522#if defined(ALPHA)
1523 c0 = c0 * (DATA_TYPE)ALPHA;
1524#if M0 > 1
1525 c1 = c1 * (DATA_TYPE)ALPHA;
1526#endif // M0 > 1
1527#if M0 > 2
1528 c2 = c2 * (DATA_TYPE)ALPHA;
1529#endif // M0 > 2
1530#if M0 > 3
1531 c3 = c3 * (DATA_TYPE)ALPHA;
1532#endif // M0 > 3
1533#if M0 > 4
1534 c4 = c4 * (DATA_TYPE)ALPHA;
1535#endif // M0 > 4
1536#if M0 > 5
1537 c5 = c5 * (DATA_TYPE)ALPHA;
1538#endif // M0 > 5
1539#if M0 > 6
1540 c6 = c6 * (DATA_TYPE)ALPHA;
1541#endif // M0 > 5
1542#if M0 > 7
1543 c7 = c7 * (DATA_TYPE)ALPHA;
1544#endif // M0 > 7
1545#endif // defined(ALPHA)
1546
1547 // Store output block
1548 VSTORE(N0)
1549 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
1550#if M0 > 1
1551 VSTORE(N0)
1552 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
1553#endif // M0 > 1
1554#if M0 > 2
1555 VSTORE(N0)
1556 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
1557#endif // M0 > 2
1558#if M0 > 3
1559 VSTORE(N0)
1560 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
1561#endif // M0 > 3
1562#if M0 > 4
1563 VSTORE(N0)
1564 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
1565#endif // M0 > 4
1566#if M0 > 5
1567 VSTORE(N0)
1568 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
1569#endif // M0 > 5
1570#if M0 > 6
1571 VSTORE(N0)
1572 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
1573#endif // M0 > 6
1574#if M0 > 7
1575 VSTORE(N0)
1576 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
1577#endif // M0 > 7
1578
1579#undef LHS_BLOCK_SIZE
1580#undef LHS_OFFSET_X
1581#undef LHS_STEP_X
1582#undef RHS_BLOCK_SIZE
1583#undef RHS_OFFSET_X
1584#undef RHS_STEP_X
1585}
1586#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
1587
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// The kernel only copies data, so the element type is chosen purely by its size in bytes (-DELEMENT_SIZE)
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is not 1, 2 or 4 (an undefined ELEMENT_SIZE evaluates to 0 here)
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (i.e. -DELEMENT_SIZE=4). Supported values: 1, 2 and 4
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped:
    // - the block read by work-item x is written to output row (x / MULT_TRANSPOSE1XW_WIDTH)
    // - within that row, the blocks of source row y start at y * TRANSPOSE_W * MULT_TRANSPOSE1XW_WIDTH elements,
    //   and (x % MULT_TRANSPOSE1XW_WIDTH) selects which TRANSPOSE_W-wide slot this work-item fills
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    // Copy TRANSPOSE_W consecutive elements as-is; only their position changes
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001646
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 * will be simply unrolled.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of rows, i.e. multiplied by src_stride_y (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix A interleaved - destination. Each work-item writes one 4x4 block (16 elements);
    // MULT_INTERLEAVE4X4_HEIGHT consecutive blocks along Y are packed side by side in the same output row
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

#if defined(REINTERPRET_INPUT_AS_3D)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;

    // Clamp to the last plane. Both operands are uint4 so the vector min(gentype, gentype) overload is used
    zin = min(zin, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Compute address for source tensor (only needed when the input is read as a plain 2D/3D tensor)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Start of this work-item's 4x4 block in the destination; consecutive 4-element groups are
    // MULT_INTERLEAVE4X4_HEIGHT * 4 elements apart because blocks from neighbouring rows are interleaved
    __global DATA_TYPE *dst_block = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

#if defined(UNROLL_BLOCK)
    vstore4(a0, 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a1, 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a2, 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a3, 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#else  // defined(UNROLL_BLOCK)
    // Transpose the 4x4 block: output group i holds column i of the loaded rows
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001774
Gian Marco36a0a462018-01-12 10:21:40 +00001775#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW (1x4 blocks) before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of rows, i.e. multiplied by dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's 4-element slot inside the interleaved (A) / transposed (B) rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (before applying the per-work-item slot offset)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two k-steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one k-step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;

    // Clamp to the last plane. Both operands are uint4 so the vector min(gentype, gentype) overload is used
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
1952
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
 *
 * Each work-item computes one 4x4 block of the output matrix (four float4 rows).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (i.e. M)
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements, i.e. number of padding rows between
 *                                                two consecutive planes; converted to bytes with dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element slice within the reshaped rows of A and B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

// Number of steps along the reshaped row of B: each step consumes 4 * MULT_TRANSPOSE1XW_WIDTH floats of B
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop: manually unrolled by 4 steps
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // Left-over steps (COLS_MTX_B not a multiple of 4)
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma(a0.s0, b0.s0, c00);
        c01 = fma(a0.s0, b0.s1, c01);
        c02 = fma(a0.s0, b0.s2, c02);
        c03 = fma(a0.s0, b0.s3, c03);

        c10 = fma(a0.s1, b0.s0, c10);
        c11 = fma(a0.s1, b0.s1, c11);
        c12 = fma(a0.s1, b0.s2, c12);
        c13 = fma(a0.s1, b0.s3, c13);

        c20 = fma(a0.s2, b0.s0, c20);
        c21 = fma(a0.s2, b0.s1, c21);
        c22 = fma(a0.s2, b0.s2, c22);
        c23 = fma(a0.s2, b0.s3, c23);

        c30 = fma(a0.s3, b0.s0, c30);
        c31 = fma(a0.s3, b0.s1, c31);
        c32 = fma(a0.s3, b0.s2, c32);
        c33 = fma(a0.s3, b0.s3, c33);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Never go past the last plane, so at most (DEPTH_GEMM3D - 1) cross-plane paddings are added.
    // Both operands are uint4 so the min(gentype, gentype) built-in overload is selected unambiguously.
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
2265
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002266#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Each work-item computes one 4x8 block of the output matrix (four half8 rows), accumulating in half precision.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (i.e. M)
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements, i.e. number of padding rows between
 *                                                two consecutive planes; converted to bytes with dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's slice within the reshaped rows of A (4 wide) and B (8 wide)
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B. It is taken before offset_row_b is applied, so the
    // loop bounds below refer to the start of the reshaped row of B.
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // Main loop: two steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Left-over step (if any)
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Never go past the last plane, so at most (DEPTH_GEMM3D - 1) cross-plane paddings are added.
    // Both operands are uint4 so the min(gentype, gentype) built-in overload is selected unambiguously.
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002443
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00002444/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
2445 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
2446 *
2447 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2448 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2449 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2450 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2451 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2452 *
2453 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2454 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2455 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2456 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2457 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2458 *
2459 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2460 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2461 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2462 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2463 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2464 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2465 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2466 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2467 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2468 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2469 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2470 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2471 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2472 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2473 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2474 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2475 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2476 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2477 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2478 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2479 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2480 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2481 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
                                                        uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                        ,
                                                        uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                       )
{
    // MULT_TRANSPOSE1XW_WIDTH work-items along x share one row of the reshaped matrix B and
    // MULT_INTERLEAVE4X4_HEIGHT work-items along y share one row of the reshaped matrix A
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-wide slice of A and 8-wide slice of B inside the shared rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // End of the row of matrix B. It is computed before offset_row_b is applied; the offset is
    // smaller than the per-step stride (8 * MULT_TRANSPOSE1XW_WIDTH), so the loop bounds still hold
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: the F16 inputs are converted to float and accumulated in F32
    float8 c00 = 0.0f;
    float8 c10 = 0.0f;
    float8 c20 = 0.0f;
    float8 c30 = 0.0f;

    // Main loop: two 4x8 outer-product updates per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one 4x8 outer-product update per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product (still in F32, before the conversion to half)
    c00 = c00 * (float8)ALPHA;
    c10 = c10 * (float8)ALPHA;
    c20 = c20 * (float8)ALPHA;
    c30 = c30 * (float8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector comes first to match the min(gentype x, sgentype y) overload
    zout = min(zout, (uint)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
2620
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note Products are accumulated in F16 (half8 accumulators); gemm_mm_interleaved_transposed_f16_acc32 is the variant that accumulates in F32
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
                                                          uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                          ,
                                                          uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                         )
{
    // MULT_TRANSPOSE1XW_WIDTH work-items along x share one row of the reshaped matrix B and
    // MULT_INTERLEAVE4X4_HEIGHT work-items along y share one row of the reshaped matrix A
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-wide slice of A and 8-wide slice of B inside the shared rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (F16 accumulation)
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

// Number of 4x8 outer-product updates this work-item performs (undefined again after this kernel)
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop: four 4x8 outer-product updates per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 the A values of two consecutive updates are contiguous,
        // so a single vload8 fetches them: .s0-.s3 for the first update, .s4-.s7 for the second
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over loop: one 4x8 outer-product update per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector comes first to match the min(gentype x, sgentype y) overload
    zout = min(zout, (uint)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitas84225582018-05-14 12:00:05 +01002876
2877// Undefine local defines
2878#undef COLS_MTX_B
2879
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002880#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002881
Gian Marco36a0a462018-01-12 10:21:40 +00002882#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002883
// Every clause uses defined(): the original tested the value of NUM_ELEMS_PROCESSED_PER_THREAD_Y instead of its presence
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
// Vector of NUM_ELEMS_PROCESSED_PER_THREAD_X elements of DATA_TYPE (one work-item's slice of a row of matrix B / dst)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01002887/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002888 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002889 * @note This OpenCL kernel works with floating point data types (F16/F32)
2890 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
2891 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002892 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002893 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2894 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002895 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002896 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2897 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002898 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2899 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2900 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2901 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2902 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002903 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002904 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2905 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2906 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2907 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2908 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002909 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002910 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2911 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2912 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2913 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2914 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002915 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002916 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
2920 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002921 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2922 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2923 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002924 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2925 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002926 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002927__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
2928 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002929 IMAGE_DECLARATION(dst),
2930 uint src0_stride_z,
2931 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002932 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002933#if defined(REINTERPRET_INPUT_AS_3D)
2934 ,
2935 uint src_cross_plane_pad
2936#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002937#if defined(REINTERPRET_OUTPUT_AS_3D)
2938 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002939 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002940#endif // REINTERPRET_OUTPUT_AS_3D
2941 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002942{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002943 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002944
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002945 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002946 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002947
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002948 // Update address for the matrix A
2949 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002950
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002951 // Update address for the matrix B
2952 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002953
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002954#if defined(REINTERPRET_INPUT_AS_3D)
2955 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2956 // in order to take into account the presence of possible cross plane paddings
2957 //
2958 // | |
2959 // | plane0 |
2960 // | |
2961 // |__________________|
2962 // |******************|
2963 // | cross_plane_pad |
2964 // |******************|
2965 // | |
2966 // | plane1 |
2967 // | |
2968 // |__________________|
2969
2970 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2971 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2972 zin = min(DEPTH_GEMM3D - 1, zin);
2973
2974 // Add offset due to the cross plane paddings
2975 zin *= (src_cross_plane_pad * src0_stride_y);
2976
2977 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2978 // multiply src0_stride_z by DEPTH_GEMM3D
2979 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2980
2981#else // defined(REINTERPRET_INPUT_AS_3D)
2982
Gian Marcoae2af742018-02-15 12:35:44 +00002983 // Add offset for batched GEMM
2984 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002985
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002986#endif // defined(REINTERPRET_INPUT_AS_3D)
2987
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002988#if defined(MATRIX_B_DEPTH)
2989 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2990 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2991#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002992 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002993#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002994
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002995 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
2996
2997 VECTOR_TYPE acc0 = 0.0f;
2998#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2999 VECTOR_TYPE acc1 = 0.0f;
3000#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3001#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3002 VECTOR_TYPE acc2 = 0.0f;
3003#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3004#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3005 VECTOR_TYPE acc3 = 0.0f;
3006#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3007
Georgios Pinitas96880cf2017-10-20 18:52:20 +01003008 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003009 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003010#if defined(REINTERPRET_INPUT_AS_3D)
3011 // Load values from matrix A
3012 VEC_DATA_TYPE(DATA_TYPE, 2)
3013 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3014#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3015 VEC_DATA_TYPE(DATA_TYPE, 2)
3016 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3018#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3019 VEC_DATA_TYPE(DATA_TYPE, 2)
3020 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3021#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3022#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3023 VEC_DATA_TYPE(DATA_TYPE, 2)
3024 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3025#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3026#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003027 // Load values from matrix A
3028 VEC_DATA_TYPE(DATA_TYPE, 2)
3029 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3030#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3031 VEC_DATA_TYPE(DATA_TYPE, 2)
3032 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3033#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3034#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3035 VEC_DATA_TYPE(DATA_TYPE, 2)
3036 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3037#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3038#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3039 VEC_DATA_TYPE(DATA_TYPE, 2)
3040 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3041#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003042#endif // defined(REINTERPRET_INPUT_AS_3D)
3043
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003044 // Load values from matrix B
3045 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
3046 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003047
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003048 // Accumulate
3049 acc0 += b0 * (VECTOR_TYPE)a0.s0;
3050 acc0 += b1 * (VECTOR_TYPE)a0.s1;
3051#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3052 acc1 += b0 * (VECTOR_TYPE)a1.s0;
3053 acc1 += b1 * (VECTOR_TYPE)a1.s1;
3054#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3055#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3056 acc2 += b0 * (VECTOR_TYPE)a2.s0;
3057 acc2 += b1 * (VECTOR_TYPE)a2.s1;
3058#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3059#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3060 acc3 += b0 * (VECTOR_TYPE)a3.s0;
3061 acc3 += b1 * (VECTOR_TYPE)a3.s1;
3062#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003063 }
3064
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003065 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003066 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003067#if defined(REINTERPRET_INPUT_AS_3D)
3068 // Load values from matrix A
3069 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3070#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3071 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3072#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3073#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3074 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3075#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3076#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3077 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3078#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3079#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003080 // Load values from matrix A
3081 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3082#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3083 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3084#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3085#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3086 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3087#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3088#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3089 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3090#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003091#endif // defined(REINTERPRET_INPUT_AS_3D)
3092
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003093 // Load values from matrix B
3094 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003095
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003096 // Accumulate
3097 acc0 += b0 * (VECTOR_TYPE)a0;
3098#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3099 acc1 += b0 * (VECTOR_TYPE)a1;
3100#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3101#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3102 acc2 += b0 * (VECTOR_TYPE)a2;
3103#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3104#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3105 acc3 += b0 * (VECTOR_TYPE)a3;
3106#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003107 }
3108
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003109 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003110 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3111
Gian Marcoae2af742018-02-15 12:35:44 +00003112 // Compute dst address
3113 __global uchar *dst_addr = offset(&dst, 0, 0);
3114
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003115 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003116#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003117 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003118#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003119#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3120 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
3121#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3122#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3123 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
3124#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3125#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3126 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
3127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3128
3129 int z = get_global_id(2);
3130
3131#if defined(REINTERPRET_OUTPUT_AS_3D)
3132 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003133 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003134 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003135 // | |
3136 // | plane0 |
3137 // | |
3138 // |__________________|
3139 // |******************|
3140 // | cross_plane_pad |
3141 // |******************|
3142 // | |
3143 // | plane1 |
3144 // | |
3145 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003146
3147 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3148 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3149 zout = min(DEPTH_GEMM3D - 1, zout);
3150
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003151 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003152 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003153
3154 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3155 // multiply dst_stride_z by DEPTH_GEMM3D
3156 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3157
3158 // Store output block
3159 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3160 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
3161#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3162 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3163 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
3164#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3165#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3166 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3167 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
3168#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3169#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3170 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3171 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
3172#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3173
3174#else // defined(REINTERPRET_OUTPUT_AS_3D)
3175 // Add offset for batched GEMM
3176 dst_addr += z * dst_stride_z;
3177
3178 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003179 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003180 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003181#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003182 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003183 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003184#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3185#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003186 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003187 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003188#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3189#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003190 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003191 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003192#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003193#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003194}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003195#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003196
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003197/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003198 *
3199 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
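3200 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3201 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4. Note that the kernel body always loads and stores 4 columns per work-item (vload4/vstore4) and handles at most 4 rows (NUM_ELEMS_PROCESSED_PER_THREAD_Y <= 4).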
3202 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3203 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003204 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3205 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003206 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003207 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3208 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003209 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3210 * -# HEIGHT_GEMM3D: The height of the 3D tensor (input and/or output) that is reinterpreted as a 2D matrix.
3211 * -# DEPTH_GEMM3D: The depth of the 3D tensor (input and/or output) that is reinterpreted as a 2D matrix.
3212 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A (not reshaped)
3213 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003214 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3215 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3216 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
3217 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3218 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
3219 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3220 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3221 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3222 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
3223 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3224 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
3225 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
3226 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
3227 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3228 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
3229 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3230 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
3231 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003232 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3233 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3234 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003235 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements along Y (i.e. rows) for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3236 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements along Y (i.e. rows) for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003237 */
3238__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
3239 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00003240 IMAGE_DECLARATION(dst),
3241 uint src0_stride_z,
3242 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003243 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003244#if defined(REINTERPRET_INPUT_AS_3D)
3245 ,
3246 uint src_cross_plane_pad
3247#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003248#if defined(REINTERPRET_OUTPUT_AS_3D)
3249 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003250 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003251#endif // REINTERPRET_OUTPUT_AS_3D
3252 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003253{
// Each work-item computes a tile of dst with up to 4 rows (NUM_ELEMS_PROCESSED_PER_THREAD_Y, see the
// "#if ... > 3" guards) and exactly 4 columns (all B loads and dst stores below use vload4/vstore4).
// NOTE(review): the 4-column width does not follow NUM_ELEMS_PROCESSED_PER_THREAD_X, which only scales idx;
// presumably the host always builds this kernel with NUM_ELEMS_PROCESSED_PER_THREAD_X=4 -- confirm.
// idx: index of the first column of matrix B handled by this work-item
3254 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3255
3256 // Compute starting address for matrix A and matrix B
3257 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3258
3259 // Update address for matrix A
3260 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3261
3262 // Update address for matrix B
3263 src_addr.s1 += idx * sizeof(float);
3264
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003265#if defined(REINTERPRET_INPUT_AS_3D)
3266 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3267 // in order to take into account the presence of possible cross plane paddings
3268 //
3269 // | |
3270 // | plane0 |
3271 // | |
3272 // |__________________|
3273 // |******************|
3274 // | cross_plane_pad |
3275 // |******************|
3276 // | |
3277 // | plane1 |
3278 // | |
3279 // |__________________|
3280
3281 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3282 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
// Clamp the plane index of each of the 4 tile rows to the last plane (DEPTH_GEMM3D - 1)
3283 zin = min(DEPTH_GEMM3D - 1, zin);
3284
3285 // Add offset due to the cross plane paddings
3286 zin *= (src_cross_plane_pad * src0_stride_y);
// From here on, lane r of zin is a byte offset (padding rows * src0_stride_y) added to the address of tile row r of A
3287
3288 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3289 // multiply src0_stride_z by DEPTH_GEMM3D
3290 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3291
3292#else // defined(REINTERPRET_INPUT_AS_3D)
3293
Gian Marcoae2af742018-02-15 12:35:44 +00003294 // Add offset for batched GEMM
3295 src_addr.s0 += get_global_id(2) * src0_stride_z;
3296
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003297#endif // defined(REINTERPRET_INPUT_AS_3D)
3298
// Select the matrix B slice used for this batch index (get_global_id(2))
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003299#if defined(MATRIX_B_DEPTH)
3300 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3301 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3302#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003303 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003304#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003305
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003306 // Initialize accumulators
// accRC accumulates the output element at tile row R, tile column C
3307 float acc00 = 0.0f;
3308 float acc01 = 0.0f;
3309 float acc02 = 0.0f;
3310 float acc03 = 0.0f;
3311
3312#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3313 float acc10 = 0.0f;
3314 float acc11 = 0.0f;
3315 float acc12 = 0.0f;
3316 float acc13 = 0.0f;
3317#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3318
3319#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3320 float acc20 = 0.0f;
3321 float acc21 = 0.0f;
3322 float acc22 = 0.0f;
3323 float acc23 = 0.0f;
3324#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3325
3326#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3327 float acc30 = 0.0f;
3328 float acc31 = 0.0f;
3329 float acc32 = 0.0f;
3330 float acc33 = 0.0f;
3331#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3332
// Main loop: each iteration consumes 4 columns of A (one vload4 per tile row) and 4 consecutive rows of B
// (one vload4 each, with src_addr.s1 advanced by src1_stride_y after every load).
// Every accumulator is updated through an fma chain in ascending k order; keep this order if editing,
// as reordering changes floating-point rounding.
3333 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003334 int i = 0;
3335 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003336 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003337#if defined(REINTERPRET_INPUT_AS_3D)
3338 // Load values from matrix A and matrix B
3339 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3340#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3341 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3342#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3343#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3344 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3345#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3346#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3347 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3348#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3349#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003350 // Load values from matrix A and matrix B
3351 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003352#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003353 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003354#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3355#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003356 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003357#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3358#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003359 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003360#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003361#endif // defined(REINTERPRET_INPUT_AS_3D)
3362
// k = i: row i of B, multiplied by component .s0 of each A row
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003363 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3364 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003365
3366 // Multiply and accumulate
3367 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003368 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003369 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003370 acc03 = fma(a0.s0, b0.s3, acc03);
3371
3372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003373
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003374 acc10 = fma(a1.s0, b0.s0, acc10);
3375 acc11 = fma(a1.s0, b0.s1, acc11);
3376 acc12 = fma(a1.s0, b0.s2, acc12);
3377 acc13 = fma(a1.s0, b0.s3, acc13);
3378
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003379#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3380#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003381
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003382 acc20 = fma(a2.s0, b0.s0, acc20);
3383 acc21 = fma(a2.s0, b0.s1, acc21);
3384 acc22 = fma(a2.s0, b0.s2, acc22);
3385 acc23 = fma(a2.s0, b0.s3, acc23);
3386
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003387#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3388#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003389
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003390 acc30 = fma(a3.s0, b0.s0, acc30);
3391 acc31 = fma(a3.s0, b0.s1, acc31);
3392 acc32 = fma(a3.s0, b0.s2, acc32);
3393 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003394#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003395
// k = i + 1: next row of B, multiplied by component .s1 of each A row
3396 // Load values from matrix A and matrix B
3397 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3398 src_addr.s1 += src1_stride_y;
3399
3400 // Multiply and accumulate
3401 acc00 = fma(a0.s1, b0.s0, acc00);
3402 acc01 = fma(a0.s1, b0.s1, acc01);
3403 acc02 = fma(a0.s1, b0.s2, acc02);
3404 acc03 = fma(a0.s1, b0.s3, acc03);
3405
3406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3407
3408 acc10 = fma(a1.s1, b0.s0, acc10);
3409 acc11 = fma(a1.s1, b0.s1, acc11);
3410 acc12 = fma(a1.s1, b0.s2, acc12);
3411 acc13 = fma(a1.s1, b0.s3, acc13);
3412
3413#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3414#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3415
3416 acc20 = fma(a2.s1, b0.s0, acc20);
3417 acc21 = fma(a2.s1, b0.s1, acc21);
3418 acc22 = fma(a2.s1, b0.s2, acc22);
3419 acc23 = fma(a2.s1, b0.s3, acc23);
3420
3421#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3422#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3423
3424 acc30 = fma(a3.s1, b0.s0, acc30);
3425 acc31 = fma(a3.s1, b0.s1, acc31);
3426 acc32 = fma(a3.s1, b0.s2, acc32);
3427 acc33 = fma(a3.s1, b0.s3, acc33);
3428#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3429
// k = i + 2: next row of B, multiplied by component .s2 of each A row
3430 // Load values from matrix A and matrix B
3431 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3432 src_addr.s1 += src1_stride_y;
3433
3434 // Multiply and accumulate
3435 acc00 = fma(a0.s2, b0.s0, acc00);
3436 acc01 = fma(a0.s2, b0.s1, acc01);
3437 acc02 = fma(a0.s2, b0.s2, acc02);
3438 acc03 = fma(a0.s2, b0.s3, acc03);
3439
3440#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3441
3442 acc10 = fma(a1.s2, b0.s0, acc10);
3443 acc11 = fma(a1.s2, b0.s1, acc11);
3444 acc12 = fma(a1.s2, b0.s2, acc12);
3445 acc13 = fma(a1.s2, b0.s3, acc13);
3446
3447#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3448#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3449
3450 acc20 = fma(a2.s2, b0.s0, acc20);
3451 acc21 = fma(a2.s2, b0.s1, acc21);
3452 acc22 = fma(a2.s2, b0.s2, acc22);
3453 acc23 = fma(a2.s2, b0.s3, acc23);
3454
3455#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3456#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3457
3458 acc30 = fma(a3.s2, b0.s0, acc30);
3459 acc31 = fma(a3.s2, b0.s1, acc31);
3460 acc32 = fma(a3.s2, b0.s2, acc32);
3461 acc33 = fma(a3.s2, b0.s3, acc33);
3462#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3463
// k = i + 3: next row of B, multiplied by component .s3 of each A row
3464 // Load values from matrix A and matrix B
3465 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3466 src_addr.s1 += src1_stride_y;
3467
3468 // Multiply and accumulate
3469 acc00 = fma(a0.s3, b0.s0, acc00);
3470 acc01 = fma(a0.s3, b0.s1, acc01);
3471 acc02 = fma(a0.s3, b0.s2, acc02);
3472 acc03 = fma(a0.s3, b0.s3, acc03);
3473
3474#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3475
3476 acc10 = fma(a1.s3, b0.s0, acc10);
3477 acc11 = fma(a1.s3, b0.s1, acc11);
3478 acc12 = fma(a1.s3, b0.s2, acc12);
3479 acc13 = fma(a1.s3, b0.s3, acc13);
3480
3481#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3482#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3483
3484 acc20 = fma(a2.s3, b0.s0, acc20);
3485 acc21 = fma(a2.s3, b0.s1, acc21);
3486 acc22 = fma(a2.s3, b0.s2, acc22);
3487 acc23 = fma(a2.s3, b0.s3, acc23);
3488
3489#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3490#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3491
3492 acc30 = fma(a3.s3, b0.s0, acc30);
3493 acc31 = fma(a3.s3, b0.s1, acc31);
3494 acc32 = fma(a3.s3, b0.s2, acc32);
3495 acc33 = fma(a3.s3, b0.s3, acc33);
3496#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3497
// Advance A by the 4 columns consumed above (B was already advanced row by row)
3498 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003499 }
3500
// Tail loop: processes the remaining COLS_A % 4 columns of A (rows of B) one at a time,
// continuing the same ascending-k fma chains
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003501 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003502 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003503#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003504 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003505 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3506#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3507 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3508#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3509#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3510 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3511#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3512#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3513 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3514#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3515#else // defined(REINTERPRET_INPUT_AS_3D)
3516 // Load values from matrix A
3517 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003518#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3519 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3520#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3521#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3522 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3523#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3524#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3525 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3526#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003527#endif // defined(REINTERPRET_INPUT_AS_3D)
3528
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003529 // Load values from matrix B
3530 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003531 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003532
3533 // Multiply and accumulate
3534 acc00 = fma(a0, b0.s0, acc00);
3535 acc01 = fma(a0, b0.s1, acc01);
3536 acc02 = fma(a0, b0.s2, acc02);
3537 acc03 = fma(a0, b0.s3, acc03);
3538#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3539 acc10 = fma(a1, b0.s0, acc10);
3540 acc11 = fma(a1, b0.s1, acc11);
3541 acc12 = fma(a1, b0.s2, acc12);
3542 acc13 = fma(a1, b0.s3, acc13);
3543#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3544#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3545 acc20 = fma(a2, b0.s0, acc20);
3546 acc21 = fma(a2, b0.s1, acc21);
3547 acc22 = fma(a2, b0.s2, acc22);
3548 acc23 = fma(a2, b0.s3, acc23);
3549#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3550#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3551 acc30 = fma(a3, b0.s0, acc30);
3552 acc31 = fma(a3, b0.s1, acc31);
3553 acc32 = fma(a3, b0.s2, acc32);
3554 acc33 = fma(a3, b0.s3, acc33);
3555#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003556
3557 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003558 }
3559
// z: batch index, used below to offset dst by dst_stride_z (times DEPTH_GEMM3D in the 3D-output case)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003560 int z = get_global_id(2);
3561
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003562 // Compute destination address
3563 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3564
3565 // Multiply by the weight of matrix-matrix product and store the result
3566#if defined(ALPHA)
3567 acc00 = acc00 * ALPHA;
3568 acc01 = acc01 * ALPHA;
3569 acc02 = acc02 * ALPHA;
3570 acc03 = acc03 * ALPHA;
3571#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003572#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003573 acc10 = acc10 * ALPHA;
3574 acc11 = acc11 * ALPHA;
3575 acc12 = acc12 * ALPHA;
3576 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003577#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3578#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003579 acc20 = acc20 * ALPHA;
3580 acc21 = acc21 * ALPHA;
3581 acc22 = acc22 * ALPHA;
3582 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003583#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3584#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003585 acc30 = acc30 * ALPHA;
3586 acc31 = acc31 * ALPHA;
3587 acc32 = acc32 * ALPHA;
3588 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003589#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3590
3591 // Compute dst address
3592 __global uchar *dst_addr = offset(&dst, 0, 0);
3593
3594#if defined(REINTERPRET_OUTPUT_AS_3D)
3595 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003596 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003597 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003598 // | |
3599 // | plane0 |
3600 // | |
3601 // |__________________|
3602 // |******************|
3603 // | cross_plane_pad |
3604 // |******************|
3605 // | |
3606 // | plane1 |
3607 // | |
3608 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003609
3610 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3611 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
// Clamp the plane index of each of the 4 tile rows to the last plane (DEPTH_GEMM3D - 1)
3612 zout = min(DEPTH_GEMM3D - 1, zout);
3613
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003614 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003615 zout *= (dst_cross_plane_pad * dst_stride_y);
// From here on, lane r of zout is a byte offset (padding rows * dst_stride_y) added to the address of tile row r of dst
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003616
3617 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3618 // multiply dst_stride_z by DEPTH_GEMM3D
3619 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3620
3621 // Store the output block
3622 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3623#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3624 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3625#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3626#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3627 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3628#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3629#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3630 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003631#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003632
3633#else // defined(REINTERPRET_OUTPUT_AS_3D)
3634 // Add offset for batched GEMM
3635 dst_addr += z * dst_stride_z;
3636
3637 // Store the output block
3638 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
3639#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3640 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
3641#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3642#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3643 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
3644#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3645#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3646 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
3647#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3648#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003649}
3650
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) only and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always loads (vload2) and stores (vstore2) exactly 2 output columns starting at
 *       get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X, so this kernel expects -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 *       At most 4 rows per work-item are handled (accumulators acc0x..acc3x), i.e. NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in [1, 4].
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the 3D tensor (input and/or output) in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the 3D tensor (input and/or output) in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes a NUM_ELEMS_PROCESSED_PER_THREAD_Y x 2 output tile:
    // matrix A is read 8 values per row at a time (vload8), matrix B 2 values per row (vload2).
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of A (and 8 rows of B) per iteration.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }
    // Left-over loop: process the remaining (COLS_A % 8) columns of A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product (only when -DALPHA is given)
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
4037
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004038#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4040 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
4042 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4043 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
4044 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4045 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
4046 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
4047 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
4048 *
4049 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4050 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
4051 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4052 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4053 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
4055 *
4056 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4057 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4058 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4059 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4060 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4061 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4062 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4063 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4064 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4065 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4066 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4067 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
4068 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4069 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4070 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4071 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4072 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
4073 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
4074 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4075 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4076 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
4077 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4078 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
4079 */
4080__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
4081 IMAGE_DECLARATION(src1),
4082 IMAGE_DECLARATION(dst),
4083 uint src0_stride_z,
4084 uint src1_stride_z,
4085 uint dst_stride_z
4086#if defined(REINTERPRET_INPUT_AS_3D)
4087 ,
4088 uint src_cross_plane_pad
4089#endif // REINTERPRET_INPUT_AS_3D
4090#if defined(REINTERPRET_OUTPUT_AS_3D)
4091 ,
4092 uint dst_cross_plane_pad
4093#endif // REINTERPRET_OUTPUT_AS_3D
4094 )
4095{
4096 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4097
4098 // Compute starting address for matrix A and Matrix B
4099 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4100
4101 // Update address for the matrix A
4102 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4103
4104 // Update address for the matrix B
4105 src_addr.s1 += idx * sizeof(half);
4106
4107#if defined(REINTERPRET_INPUT_AS_3D)
4108 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4109 // in order to take into account the presence of possible cross plane paddings
4110 //
4111 // | |
4112 // | plane0 |
4113 // | |
4114 // |__________________|
4115 // |******************|
4116 // | cross_plane_pad |
4117 // |******************|
4118 // | |
4119 // | plane1 |
4120 // | |
4121 // |__________________|
4122
4123 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4124 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4125 zin = min(DEPTH_GEMM3D - 1, zin);
4126
4127 // Add offset due to the cross plane paddings
4128 zin *= (src_cross_plane_pad * src0_stride_y);
4129
4130 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4131 // multiply src0_stride_z by DEPTH_GEMM3D
4132 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4133
4134#else // defined(REINTERPRET_INPUT_AS_3D)
4135
4136 // Add offset for batched GEMM
4137 src_addr.s0 += get_global_id(2) * src0_stride_z;
4138
4139#endif // defined(REINTERPRET_INPUT_AS_3D)
4140
4141#if defined(MATRIX_B_DEPTH)
4142 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4143 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4144#else // defined(MATRIX_B_DEPTH)
4145 src_addr.s1 += get_global_id(2) * src1_stride_z;
4146#endif // defined(MATRIX_B_DEPTH)
4147
4148 float8 acc0 = 0.0h;
4149#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4150 float8 acc1 = 0.0h;
4151#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4152#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4153 float8 acc2 = 0.0h;
4154#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4155#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4156 float8 acc3 = 0.0h;
4157#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4158
4159 int i = 0;
4160 for(; i <= ((int)COLS_A - 4); i += 4)
4161 {
4162#if defined(REINTERPRET_INPUT_AS_3D)
4163 // Load values from matrix A
4164 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4165#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4166 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4167#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4168#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4169 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4170#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4171#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4172 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4173#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4174#else // defined(REINTERPRET_INPUT_AS_3D)
4175 // Load values from matrix A
4176 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4177#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4178 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4179#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4180#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4181 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4182#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4183#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4184 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4185#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4186#endif // defined(REINTERPRET_INPUT_AS_3D)
4187
4188 // Load values from matrix B
4189 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4190 src_addr.s1 += src1_stride_y;
4191
4192 // Accumulate
4193 acc0 = fma(b0, (float8)a0.s0, acc0);
4194#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4195 acc1 = fma(b0, (float8)a1.s0, acc1);
4196#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4197#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4198 acc2 = fma(b0, (float8)a2.s0, acc2);
4199#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4200#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4201 acc3 = fma(b0, (float8)a3.s0, acc3);
4202#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4203
4204 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4205 src_addr.s1 += src1_stride_y;
4206 acc0 = fma(b0, (float8)a0.s1, acc0);
4207#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4208 acc1 = fma(b0, (float8)a1.s1, acc1);
4209#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4210#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4211 acc2 = fma(b0, (float8)a2.s1, acc2);
4212#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4213#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4214 acc3 = fma(b0, (float8)a3.s1, acc3);
4215#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4216
4217 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4218 src_addr.s1 += src1_stride_y;
4219 acc0 = fma(b0, (float8)a0.s2, acc0);
4220#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4221 acc1 = fma(b0, (float8)a1.s2, acc1);
4222#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4223#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4224 acc2 = fma(b0, (float8)a2.s2, acc2);
4225#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4226#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4227 acc3 = fma(b0, (float8)a3.s2, acc3);
4228#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4229
4230 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4231 src_addr.s1 += src1_stride_y;
4232 acc0 = fma(b0, (float8)a0.s3, acc0);
4233#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4234 acc1 = fma(b0, (float8)a1.s3, acc1);
4235#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4236#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4237 acc2 = fma(b0, (float8)a2.s3, acc2);
4238#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4239#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4240 acc3 = fma(b0, (float8)a3.s3, acc3);
4241#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4242
4243 src_addr.s0 += 4 * sizeof(half);
4244 }
4245
4246 for(; i < (int)COLS_A; ++i)
4247 {
4248#if defined(REINTERPRET_INPUT_AS_3D)
4249 // Load values from matrix A
4250 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4252 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4253#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4254#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4255 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4256#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4257#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4258 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4259#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4260#else // defined(REINTERPRET_INPUT_AS_3D)
4261 // Load values from matrix A
4262 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4263#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4264 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4265#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4266#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4267 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4270 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4271#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4272#endif // defined(REINTERPRET_INPUT_AS_3D)
4273
4274 // Load values from matrix B
4275 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4276
4277 src_addr += (int2)(sizeof(half), src1_stride_y);
4278
4279 // Accumulate
4280 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
4281#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4282 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
4283#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4284#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4285 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
4286#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4287#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4288 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
4289#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4290 }
4291
4292 // Multiply by the weight of matrix-matrix product and store the result
4293#if defined(ALPHA)
4294 half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
4295#else //defined(ALPHA)
4296 half8 hacc0 = convert_half8(acc0);
4297#endif // defined(ALPHA)
4298#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4299#if defined(ALPHA)
4300 half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
4301#else //defined(ALPHA)
4302 half8 hacc1 = convert_half8(acc1);
4303#endif //defined(ALPHA)
4304#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
4305
4306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4307#if defined(ALPHA)
4308 half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
4309#else //defined(ALPHA)
4310 half8 hacc2 = convert_half8(acc2);
4311#endif //defined(ALPHA)
4312#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4313
4314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4315#if defined(ALPHA)
4316 half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
4317#else //defined(ALPHA)
4318 half8 hacc3 = convert_half8(acc3);
4319#endif // defined(ALPHA)
4320#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4321
4322 int z = get_global_id(2);
4323
4324 // Compute destination address
4325 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4326
4327 // Compute dst address
4328 __global uchar *dst_addr = offset(&dst, 0, 0);
4329
4330#if defined(REINTERPRET_OUTPUT_AS_3D)
4331 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
4332 // in order to take into account the presence of possible cross plane paddings
4333 //
4334 // | |
4335 // | plane0 |
4336 // | |
4337 // |__________________|
4338 // |******************|
4339 // | cross_plane_pad |
4340 // |******************|
4341 // | |
4342 // | plane1 |
4343 // | |
4344 // |__________________|
4345
4346 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4347 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4348 zout = min(DEPTH_GEMM3D - 1, zout);
4349
4350 // Add offset due to the cross plane paddings
4351 zout *= (dst_cross_plane_pad * dst_stride_y);
4352
4353 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4354 // multiply dst_stride_z by DEPTH_GEMM3D
4355 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4356
4357 // Store the output block
4358 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4359#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4360 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4361#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4362#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4363 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4364#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4365#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4366 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
4367#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4368
4369#else // defined(REINTERPRET_OUTPUT_AS_3D)
4370 // Add offset for batched GEMM
4371 dst_addr += z * dst_stride_z;
4372
4373 // Store the output block
4374 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
4375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4376 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
4377#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4378#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4379 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
4380#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4382 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
4383#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4384#endif // REINTERPRET_OUTPUT_AS_3D
4385}
4386
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 *       Products are accumulated in half precision (the accumulators are half8), so very long dot products (large COLS_A)
 *       lose more precision than a float accumulator would.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always loads and stores 8 consecutive output columns (vload8/vstore8), so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 is expected;
 *       with a different value adjacent work-items would read/write overlapping columns.
 *       Only NUM_ELEMS_PROCESSED_PER_THREAD_Y values from 1 to 4 are supported (accumulators acc0..acc3).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16).
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]).
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height (rows per plane) of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth (number of planes) of the input/output in case it has to be reinterpreted as a 3D tensor.
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom padding of each input plane, in rows (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom padding of each output plane, in rows (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Half-precision accumulators: one row of 8 output columns each
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Leftover loop: remaining COLS_A % 4 columns, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one element (column) and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004718#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01004719
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004720#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004721
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004722#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * dst = dst + BETA * src
 *
 * dst is expected to already contain the alpha * (A x B) result and src is the matrix C.
 * Each work-item reads and writes 4 consecutive F32 values (vload4/vstore4).
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per-work-item views: matrix C (src) and the accumulated product (dst)
    Tensor3D c_tensor   = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D out_tensor = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *out_ptr = (__global float *)out_tensor.ptr;

    // Both reads happen before the single write below, so their order is irrelevant
    const float4 c_vals  = vload4(0, (__global float *)c_tensor.ptr);
    const float4 ab_vals = vload4(0, out_ptr);

    // alpha * AxB + beta * C, written back over the AxB values
    vstore4(ab_vals + (float4)BETA * c_vals, 0, out_ptr);
}
4763
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004764#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004765/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
4766 *
 * @note The beta's value needs to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004768 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004769 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
4770 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
4771 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4772 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
4773 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004776 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004777 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004778 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4779 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4780 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4781 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004782 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
4783 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004784 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
4785 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004786__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
4787 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004788{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004789 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004790 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
4791 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004792
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004793 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004794 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
4795
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004796 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004797 half8 c = vload8(0, (__global half *)src.ptr);
4798
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004799 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004800 half8 out = alpha_ab + (half8)BETA * c;
4801
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004802 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004803 vstore8(out, 0, (__global half *)dst.ptr);
4804}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004805#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004806#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004807
#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix multiplication used by the locally connected layer.
 *
 * Row y of A (src0) is multiplied by the Z-slice y of B (src1). Each work-item produces 4 consecutive F32 values of
 * row y of the destination, starting at column 4 * get_global_id(0).
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 * @note Neither A nor B may be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in matrix A
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of matrix B in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int first_col = get_global_id(0) * 4;
    const int row       = get_global_id(1);

    // Byte offset of the current element of row "row" of A
    int a_offset = src0_offset_first_element_in_bytes + src0_stride_y * row;

    // Byte offset of the 4 columns handled by this work-item inside the Z-slice "row" of B
    int b_offset = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_offset += first_col * sizeof(float);

    // One past the last byte of row "row" of A
    const int a_row_end = a_offset + (WIDTH_VECTOR_A * sizeof(float));

    float4 result = (float4)0.0f;

    // Main loop: consume two elements of A (and therefore two rows of B) per iteration
    while(a_offset <= (a_row_end - 2 * (int)sizeof(float)))
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_offset));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_offset));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_offset + src1_stride_y));

        result += b_row0 * (float4)a_pair.s0;
        result += b_row1 * (float4)a_pair.s1;

        a_offset += 2 * (int)sizeof(float);
        b_offset += 2 * src1_stride_y;
    }

    // Tail: the main loop stops once fewer than two elements of A remain, so at most one is left here
    if(a_offset < a_row_end)
    {
        const float a_last = *((__global float *)(src0_ptr + a_offset));
        const float4 b_last = vload4(0, (__global float *)(src1_ptr + b_offset));

        result += b_last * (float4)a_last;
    }

    // Write the 4 accumulated outputs for this work-item
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(result, 0, (__global float *)(offset(&dst_img, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
4875
/** This kernel accumulates each row with the biases vector, in place.
 *
 * The same biases vector is added to every row of the accumulate tensor:
 *
 *  accum[y][x] = accum[y][x] + biases[x]
 *
 * Each work-item loads, adds and stores VECTOR_SIZE consecutive elements. The sum uses plain '+' (not add_sat),
 * so integer data types are not saturated.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE consecutive elements of the current accumulate row and the matching biases
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    // Plain addition: no saturation for integer data types
    accum_value = biases_value + accum_value;

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)
4910#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)