blob: 40ee1d45addd81c472901765ae7e6839f3cc97de [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000026#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
27
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item copies one M0xK0 block: get_global_id(0) selects the K0-wide column block, get_global_id(1) selects the
 * M0-high row block and get_global_id(2) selects the batch. V0 blocks that are consecutive along get_global_id(1) are
 * written to the same output row.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings between planes, in rows (it is multiplied by src_stride_y). Only if REINTERPRET_INPUT_AS_3D is defined
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
// Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

// Distance (in elements) between the starts of two blocks stored on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

// Distance (in elements) between two consecutive rows of the same block in the output.
// Fully parenthesized so the expansion is safe inside any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x = column block, y = row block, z = batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: top-left element of this work-item's M0xK0 block
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Output address: each x owns a span of V0 blocks in every output row, and
    // V0 consecutive row blocks (y) share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Per-row byte offsets that skip the cross-plane padding (remain 0 unless REINTERPRET_INPUT_AS_3D is defined)
    uint zin0 = 0;
    uint zin1 = 0;
    uint zin2 = 0;
    uint zin3 = 0;
    uint zin4 = 0;
    uint zin5 = 0;
    uint zin6 = 0;
    uint zin7 = 0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D, clamped to the last plane
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
    // Row 7 only exists when M0 > 7 (matches the a7 load/store guards below)
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix: one K0-wide vector per block row
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Block row i is written OUTPUT_STEP_X elements after block row i-1
    VSTORE(K0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if M0 > 1
    VSTORE(K0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 1
#if M0 > 2
    VSTORE(K0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 2
#if M0 > 3
    VSTORE(K0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 3
#if M0 > 4
    VSTORE(K0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 4
#if M0 > 5
    VSTORE(K0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 5
#if M0 > 6
    VSTORE(K0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 6
#if M0 > 7
    VSTORE(K0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 7

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
253#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
254
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000255#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Each work-item copies one K0xN0 block: get_global_id(0) selects the N0-wide column block, get_global_id(1) selects the
 * K0-high row block and get_global_id(2) selects the batch. H0 blocks that are consecutive along get_global_id(0) are
 * written to the same output row. Source rows at or beyond SRC_HEIGHT are not read; they are written out as zeros.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,4,8,16
 *                                      K0: 1,2,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
// Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

// Distance (in elements) between the starts of two blocks stored on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

// Distance (in elements) between two consecutive rows of the same block in the output.
// Fully parenthesized so the expansion is safe inside any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x = column block, y = row block, z = batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: top-left element of this work-item's K0xN0 block
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Output address: each y owns a span of H0 blocks in every output row, and
    // H0 consecutive column blocks (x) share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
                                     x / (uint)H0)
                                 * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Rows that are not loaded (beyond SRC_HEIGHT) keep this zero value and are stored as padding
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a0 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a1 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a2 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a3 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a4 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a5 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a6 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a7 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a8 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    a9 = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aA = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aB = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aC = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aD = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aE = 0;
    VEC_DATA_TYPE(DATA_TYPE, N0)
    aF = 0;

    // Load values from the RHS matrix. Row y*K0+i is read only if it lies inside the source (i.e. < SRC_HEIGHT).
    // NOTE(review): row 0 is loaded unconditionally, which assumes the dispatch never launches a y with y*K0 >= SRC_HEIGHT — confirm with the host-side window.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    // Guard on row 8 itself (previously tested row 9, which dropped row 8 when it was the last valid row)
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------

    // Block row i is written OUTPUT_STEP_X elements after block row i-1
    VSTORE(N0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if K0 > 1
    VSTORE(N0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 1
#if K0 > 2
    VSTORE(N0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 2
#if K0 > 4
    VSTORE(N0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 4
#if K0 > 8
    VSTORE(N0)
    (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
475
476#if defined(TRANSPOSE)
477/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
478 * the output matrix unrolling the values.
479 *
480 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
481 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
482 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
483 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
486 * @note Only the following values for K0, N0 and H0 are supported:
487 * N0: 2,4,8,16
488 * K0: 4,8,16
489 * H0: greater than 0
490 *
491 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
492 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
493 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
494 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
495 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
496 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
497 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
498 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
499 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
500 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
501 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
502 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
503 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
504 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
505 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
506 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
507 */
508__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
509 TENSOR3D_DECLARATION(dst))
510{
511 // Block size
512#define BLOCK_SIZE ((K0) * (N0))
513
514 // Output offset X
515#if defined(INTERLEAVE)
516#define OUTPUT_OFFSET_X (K0)
517#else // defined(INTERLEAVE)
518#define OUTPUT_OFFSET_X (BLOCK_SIZE)
519#endif // defined(INTERLEAVE)
520
521 // Output step X
522#if defined(INTERLEAVE)
523#define OUTPUT_STEP_X (K0) * (H0)
524#else // Do not interleave
525#define OUTPUT_STEP_X (K0)
526#endif // defined(INTERLEAVE)
527
528 // Compute source and destination addresses
529 uint x = get_global_id(0);
530 uint y = get_global_id(1);
531 uint z = get_global_id(2);
532
533 // ------------------ Compute input/output addresses ---------------------------
534
535 // Compute the input address
536 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
537
538 // Compute the output address
539 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
540 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;
541
542 // ---------------------------Load input values --------------------------------
543
544 VEC_DATA_TYPE(DATA_TYPE, N0)
545 a0 = 0;
546 VEC_DATA_TYPE(DATA_TYPE, N0)
547 a1 = 0;
548 VEC_DATA_TYPE(DATA_TYPE, N0)
549 a2 = 0;
550 VEC_DATA_TYPE(DATA_TYPE, N0)
551 a3 = 0;
552 VEC_DATA_TYPE(DATA_TYPE, N0)
553 a4 = 0;
554 VEC_DATA_TYPE(DATA_TYPE, N0)
555 a5 = 0;
556 VEC_DATA_TYPE(DATA_TYPE, N0)
557 a6 = 0;
558 VEC_DATA_TYPE(DATA_TYPE, N0)
559 a7 = 0;
560 VEC_DATA_TYPE(DATA_TYPE, N0)
561 a8 = 0;
562 VEC_DATA_TYPE(DATA_TYPE, N0)
563 a9 = 0;
564 VEC_DATA_TYPE(DATA_TYPE, N0)
565 aA = 0;
566 VEC_DATA_TYPE(DATA_TYPE, N0)
567 aB = 0;
568 VEC_DATA_TYPE(DATA_TYPE, N0)
569 aC = 0;
570 VEC_DATA_TYPE(DATA_TYPE, N0)
571 aD = 0;
572 VEC_DATA_TYPE(DATA_TYPE, N0)
573 aE = 0;
574 VEC_DATA_TYPE(DATA_TYPE, N0)
575 aF = 0;
576
577 // Load values from the RHS matrix
578 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
579 if(y * (uint)K0 + 1 < SRC_HEIGHT)
580 {
581 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
582 }
583 if(y * (uint)K0 + 2 < SRC_HEIGHT)
584 {
585 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
586 }
587 if(y * (uint)K0 + 3 < SRC_HEIGHT)
588 {
589 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
590 }
591#if K0 > 4
592 if(y * (uint)K0 + 4 < SRC_HEIGHT)
593 {
594 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
595 }
596 if(y * (uint)K0 + 5 < SRC_HEIGHT)
597 {
598 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
599 }
600 if(y * (uint)K0 + 6 < SRC_HEIGHT)
601 {
602 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
603 }
604 if(y * (uint)K0 + 7 < SRC_HEIGHT)
605 {
606 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
607 }
608#endif // K0 > 4
609#if K0 > 8
610 if(y * (uint)K0 + 9 < SRC_HEIGHT)
611 {
612 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
613 }
614 if(y * (uint)K0 + 9 < SRC_HEIGHT)
615 {
616 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
617 }
618 if(y * (uint)K0 + 10 < SRC_HEIGHT)
619 {
620 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
621 }
622 if(y * (uint)K0 + 11 < SRC_HEIGHT)
623 {
624 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
625 }
626 if(y * (uint)K0 + 12 < SRC_HEIGHT)
627 {
628 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
629 }
630 if(y * (uint)K0 + 13 < SRC_HEIGHT)
631 {
632 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
633 }
634 if(y * (uint)K0 + 14 < SRC_HEIGHT)
635 {
636 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
637 }
638 if(y * (uint)K0 + 15 < SRC_HEIGHT)
639 {
640 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
641 }
642#endif // K0 > 8
643
644 // ---------------------------Transpose the block ------------------------------
645
646 VEC_DATA_TYPE(DATA_TYPE, K0)
647 res0 = 0;
648 VEC_DATA_TYPE(DATA_TYPE, K0)
649 res1 = 0;
650 VEC_DATA_TYPE(DATA_TYPE, K0)
651 res2 = 0;
652 VEC_DATA_TYPE(DATA_TYPE, K0)
653 res3 = 0;
654 VEC_DATA_TYPE(DATA_TYPE, K0)
655 res4 = 0;
656 VEC_DATA_TYPE(DATA_TYPE, K0)
657 res5 = 0;
658 VEC_DATA_TYPE(DATA_TYPE, K0)
659 res6 = 0;
660 VEC_DATA_TYPE(DATA_TYPE, K0)
661 res7 = 0;
662 VEC_DATA_TYPE(DATA_TYPE, K0)
663 res8 = 0;
664 VEC_DATA_TYPE(DATA_TYPE, K0)
665 res9 = 0;
666 VEC_DATA_TYPE(DATA_TYPE, K0)
667 resA = 0;
668 VEC_DATA_TYPE(DATA_TYPE, K0)
669 resB = 0;
670 VEC_DATA_TYPE(DATA_TYPE, K0)
671 resC = 0;
672 VEC_DATA_TYPE(DATA_TYPE, K0)
673 resD = 0;
674 VEC_DATA_TYPE(DATA_TYPE, K0)
675 resE = 0;
676 VEC_DATA_TYPE(DATA_TYPE, K0)
677 resF = 0;
678
679#if K0 == 4
680 // This part computes the following transpositions:
681 // 4x2 -> 2x4
682 // 4x4 -> 4x4
683 // 4x8 -> 8x4
684 // 4x16 -> 16x4
685 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
686 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
687#if N0 > 2
688 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
689 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
690#endif // N0 > 2
691#if N0 > 4
692 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
693 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
694 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
695 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
696#endif // N0 > 4
697#if N0 > 8
698 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
699 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
700 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
701 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
702 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
703 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
704 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
705 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
706#endif // N0 > 8
707
708#elif K0 == 8 // N0 == 3
709 // This part computes the following transpositions:
710 // 8x2 -> 2x8
711 // 8x4 -> 4x8
712 // 8x8 -> 8x8
713 // 8x16 -> 16x8
714 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
715 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
716#if N0 > 2
717 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
718 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
719#endif // N0 > 2
720#if N0 > 4
721 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
722 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
723 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
724 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
725#endif // N0 > 4
726#if N0 > 8
727 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
728 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
729 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
730 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
731 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
732 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
733 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
734 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
735#endif // N0 > 8
736
737#elif K0 == 16 // N0 == 16
738
739 // This part computes the following transpositions:
740 // 16x2 -> 2x16
741 // 16x4 -> 4x16
742 // 16x8 -> 8x16
743 // 16x16 -> 16x16
744 res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
745 a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
746 res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
747 a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
748#if N0 > 2
749 res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
750 a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
751 res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
752 a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
753#endif // N0 > 2
754#if N0 > 4
755 res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
756 a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
757 res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
758 a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
759 res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
760 a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
761 res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
762 a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
763#endif // N0 > 4
764#if N0 > 8
765 res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
766 a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
767 res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
768 a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
769 resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
770 a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
771 resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
772 a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
773 resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
774 a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
775 resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
776 a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
777 resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
778 a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
779 resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
780 a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
781#endif // N0 > 8
782
783#else // N0 == 16
784#error "Not supported N0 value"
785#endif // N0 > 2
786
787 // ---------------------------Store the output values ------------------------------
788
789 VSTORE(K0)
790 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
791 VSTORE(K0)
792 (res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
793#if N0 > 2
794 VSTORE(K0)
795 (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
796 VSTORE(K0)
797 (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
798#endif // N0 > 2
799#if N0 > 4
800 VSTORE(K0)
801 (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
802 VSTORE(K0)
803 (res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
804 VSTORE(K0)
805 (res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
806 VSTORE(K0)
807 (res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
808#endif // N0 > 4
809#if N0 > 8
810 VSTORE(K0)
811 (res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
812 VSTORE(K0)
813 (res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
814 VSTORE(K0)
815 (resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
816 VSTORE(K0)
817 (resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
818 VSTORE(K0)
819 (resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
820 VSTORE(K0)
821 (resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
822 VSTORE(K0)
823 (resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
824 VSTORE(K0)
825 (resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
826#endif // N0 > 8
827
828#undef BLOCK_SIZE
829#undef OUTPUT_OFFSET_X
830#undef OUTPUT_STEP_X
831}
832#endif // defined(TRANSPOSE)
833#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
834
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// Map the element size (in bytes) onto an unsigned type of the same width: the kernel only moves bits around,
// so the signedness/float-ness of the real data type does not matter.
// NOTE(review): DATA_TYPE is defined here and never undefined, so it stays visible to the rest of this program.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is not 1, 2 or 4 (or it is not defined)
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item copies one chunk of TRANSPOSE_W consecutive elements. The chunk at (get_global_id(0), get_global_id(1))
 * is written to output row get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH, at chunk position
 * get_global_id(1) * MULT_TRANSPOSE1XW_WIDTH + get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH along that row.
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (i.e. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The element size in bytes must be passed at compile time using -DELEMENT_SIZE (i.e. -DELEMENT_SIZE=4). Supported values: 1, 2, 4
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    const uint col_block = get_global_id(0);
    const uint row       = get_global_id(1);
    const uint batch     = get_global_id(2);

    // Size in bytes of one TRANSPOSE_W-wide chunk
    const uint chunk_bytes = TRANSPOSE_W * sizeof(DATA_TYPE);

    // Swap X and Y: the source column block picks the output row, and the source row (interleaved by
    // MULT_TRANSPOSE1XW_WIDTH) picks the chunk position inside that output row
    const uint out_row   = col_block / MULT_TRANSPOSE1XW_WIDTH;
    const uint out_chunk = row * MULT_TRANSPOSE1XW_WIDTH + col_block % MULT_TRANSPOSE1XW_WIDTH;

    // Byte offset into dst, including the per-batch offset along Z
    const uint dst_offset = dst_offset_first_element_in_bytes + out_chunk * chunk_bytes + out_row * dst_stride_y + batch * dst_stride_z;

    // Source chunk for this work-item
    Tensor3D src_tensor = CONVERT_TO_TENSOR3D_STRUCT(src);

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    chunk = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src_tensor.ptr);

    VSTORE(TRANSPOSE_W)
    (chunk, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100893
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 * will be simply unrolled.
 *
 * Each work-item reads one 4x4 block of the input (4-column block get_global_id(0), 4-row block get_global_id(1)) and writes
 * its 16 values as four 4-element vectors spaced 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart on output row
 * get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // x selects the 4-column block, y the 4-row block and z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix A interleaved - destination.
    // Block (x, y) goes to output row y / MULT_INTERLEAVE4X4_HEIGHT. Along that row each block column x owns
    // 16 * MULT_INTERLEAVE4X4_HEIGHT elements, and y % MULT_INTERLEAVE4X4_HEIGHT selects the 4-element slot
    // used inside each group of 4 * MULT_INTERLEAVE4X4_HEIGHT elements.
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    __global DATA_TYPE *dst_block = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

#if defined(REINTERPRET_INPUT_AS_3D)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector argument comes first so that the call matches the
    // min(gentype, gentype) overload defined by the OpenCL C specification.
    zin = min(zin, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else // defined(REINTERPRET_INPUT_AS_3D)
    // Compute address for source tensor (only needed when the input is not reinterpreted as 3D)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(UNROLL_BLOCK)
    // Store the 4 rows as they are (no transposition of the 4x4 block)
    vstore4(a0, 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a1, 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a2, 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a3, 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#else  // defined(UNROLL_BLOCK)
    // Store the columns of the 4x4 block (i.e. the transposed block)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001021
Gian Marco36a0a462018-01-12 10:21:40 +00001022#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * Each work-item accumulates one 4x4 block of the output (four float4 accumulators) and stores it on 4 consecutive output rows.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Row of the reshaped matrices handled by this work-item
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element slot inside the interleaved/transposed rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before applying offset_row_b)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two rank-1 updates per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one rank-1 update per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector argument comes first so that the call matches the
    // min(gentype, gentype) overload defined by the OpenCL C specification.
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
1199
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1).
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
 *
 *  Each work-item computes one 4x4 block of the output, keeping the 16 partial sums in scalar accumulators (c00..c33).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          The output row index is split into planes of HEIGHT_GEMM3D rows, so (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of rows of the output, i.e. multiples of dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Several work-items share one row of reshaped A / B: these are this work-item's element offsets inside that row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cRC holds the partial sum for row R, column C of the 4x4 output block
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

// Number of K steps: each step consumes 4 * MULT_TRANSPOSE1XW_WIDTH values of the reshaped row of matrix B
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

// One K step: load 4 values of A and 4 values of B, skip past the values owned by the other
// work-items sharing the same reshaped rows, and accumulate the 4x4 outer product a0 x b0
// into c00..c33 with scalar fma (a0.sR feeds row R, b0.sC feeds column C).
#define MM_F32_BIFROST_STEP()                        \
    do                                               \
    {                                                \
        float4 a0 = vload4(0, src_addr_a);           \
        float4 b0 = vload4(0, src_addr_b);           \
                                                     \
        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; \
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;   \
                                                     \
        c00 = fma(a0.s0, b0.s0, c00);                \
        c01 = fma(a0.s0, b0.s1, c01);                \
        c02 = fma(a0.s0, b0.s2, c02);                \
        c03 = fma(a0.s0, b0.s3, c03);                \
                                                     \
        c10 = fma(a0.s1, b0.s0, c10);                \
        c11 = fma(a0.s1, b0.s1, c11);                \
        c12 = fma(a0.s1, b0.s2, c12);                \
        c13 = fma(a0.s1, b0.s3, c13);                \
                                                     \
        c20 = fma(a0.s2, b0.s0, c20);                \
        c21 = fma(a0.s2, b0.s1, c21);                \
        c22 = fma(a0.s2, b0.s2, c22);                \
        c23 = fma(a0.s2, b0.s3, c23);                \
                                                     \
        c30 = fma(a0.s3, b0.s0, c30);                \
        c31 = fma(a0.s3, b0.s1, c31);                \
        c32 = fma(a0.s3, b0.s2, c32);                \
        c33 = fma(a0.s3, b0.s3, c33);                \
    } while(0)

    int i = 0;

    // Main loop, manually unrolled by 4 K steps
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        MM_F32_BIFROST_STEP();
        MM_F32_BIFROST_STEP();
        MM_F32_BIFROST_STEP();
        MM_F32_BIFROST_STEP();
    }

    // Left-over K steps
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        MM_F32_BIFROST_STEP();
    }

#undef MM_F32_BIFROST_STEP

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane; vector argument first, as required by the OpenCL min(gentype, gentype) overload
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
1509
// COLS_MTX_B is a helper macro local to gemm_mm_interleaved_transposed_f32_bifrost;
// drop it here so the name cannot leak into the kernels that follow.
#undef COLS_MTX_B
1512
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001513#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1).
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication.
 *
 *  Each work-item computes one 4x8 block of the output and accumulates it in half precision (half8 accumulators).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          The output row index is split into planes of HEIGHT_GEMM3D rows, so (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of rows of the output, i.e. multiples of dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Several work-items share one row of reshaped A / B: these are this work-item's element offsets inside that row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B. It is taken from the start of the reshaped row,
    // i.e. before this work-item's offset_row_b is applied below.
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cR0 holds the 8 partial sums of row R of the 4x8 output block
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // Main loop: two K steps per iteration. Per K step A advances 4 * MULT_INTERLEAVE4X4_HEIGHT
    // elements and B advances 8 * MULT_TRANSPOSE1XW_WIDTH elements.
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed) for the second K step
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Left-over K steps, one per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane; vector argument first, as required by the OpenCL min(gentype, gentype) overload
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001690
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00001691/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
1692 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
1693 *
1694 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
1695 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
1696 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
1697 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1698 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1699 *
1700 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1701 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1702 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1703 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1704 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1705 *
1706 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1707 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1708 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1709 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1710 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1711 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1712 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1713 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1714 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1715 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1716 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1717 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1718 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1719 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1720 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1721 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1722 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1723 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1724 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1725 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1726 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1727 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1728 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Each work-item computes one 4x8 tile of the output (four float8 accumulators, four stored rows).
    // MULT_TRANSPOSE1XW_WIDTH consecutive work-items along x read the same row of the reshaped matrix B and
    // MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along y read the same row of the reshaped matrix A.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element A block and 8-element B block inside the shared row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the per-work-item offset is applied)
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators. They are kept in float and converted to half only when the result is stored
    float8 c00 = 0.0f;
    float8 c10 = 0.0f;
    float8 c20 = 0.0f;
    float8 c30 = 0.0f;

    // Main loop: two steps along K per iteration, while src_addr_b + 16 * MULT_TRANSPOSE1XW_WIDTH
    // does not go past the end of the B row
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed) for the next K step
        a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one step along K per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float8)ALPHA;
    c10 = c10 * (float8)ALPHA;
    c20 = c20 * (float8)ALPHA;
    c30 = c30 * (float8)ALPHA;
#endif // defined(ALPHA)

    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is an explicit uint4 so the call matches the standard
    // min(gentype, gentype) overload instead of relying on a scalar-first implicit conversion
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
1867
/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *  Both the products and the accumulation are computed in half precision (F16)
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Each work-item computes one 4x8 tile of the output. MULT_TRANSPOSE1XW_WIDTH consecutive work-items
    // along x share a row of the reshaped matrix B and MULT_INTERLEAVE4X4_HEIGHT consecutive work-items
    // along y share a row of the reshaped matrix A.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element A block and 8-element B block inside the shared row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

// Number of K steps: each step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the reshaped B row
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop: four K steps per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 the A blocks of two consecutive K steps are contiguous,
        // so a single vload8 fetches the A values for two K steps
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over loop: one K step per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is an explicit uint4 so the call matches the standard
    // min(gentype, gentype) overload instead of relying on a scalar-first implicit conversion
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitas84225582018-05-14 12:00:05 +01002123
// Undefine local defines
#undef COLS_MTX_B

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// All three compile-time parameters are tested for definedness. Testing NUM_ELEMS_PROCESSED_PER_THREAD_Y by
// value (without defined) would be a preprocessor error if it were defined empty.
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01002134/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002135 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002136 * @note This OpenCL kernel works with floating point data types (F16/F32)
2137 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
2138 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002139 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002140 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2141 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002142 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002143 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2144 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002145 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2146 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2147 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2148 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows (M) of matrix A NOT reshaped
2149 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002150 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002151 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2152 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2153 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2154 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2155 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002156 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002157 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2158 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2159 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2160 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2161 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002162 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002163 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2164 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2165 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2166 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2167 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002168 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2169 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2170 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002171 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2172 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002173 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002174__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
2175 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002176 IMAGE_DECLARATION(dst),
2177 uint src0_stride_z,
2178 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002179 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002180#if defined(REINTERPRET_INPUT_AS_3D)
2181 ,
2182 uint src_cross_plane_pad
2183#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002184#if defined(REINTERPRET_OUTPUT_AS_3D)
2185 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002186 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002187#endif // REINTERPRET_OUTPUT_AS_3D
2188 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002189{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002190 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002191
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002192 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002193 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002194
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002195 // Update address for the matrix A
2196 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002197
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002198 // Update address for the matrix B
2199 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002200
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002201#if defined(REINTERPRET_INPUT_AS_3D)
2202 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2203 // in order to take into account the presence of possible cross plane paddings
2204 //
2205 // | |
2206 // | plane0 |
2207 // | |
2208 // |__________________|
2209 // |******************|
2210 // | cross_plane_pad |
2211 // |******************|
2212 // | |
2213 // | plane1 |
2214 // | |
2215 // |__________________|
2216
2217 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2218 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2219 zin = min(DEPTH_GEMM3D - 1, zin);
2220
2221 // Add offset due to the cross plane paddings
2222 zin *= (src_cross_plane_pad * src0_stride_y);
2223
2224 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2225 // multiply src0_stride_z by DEPTH_GEMM3D
2226 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2227
2228#else // defined(REINTERPRET_INPUT_AS_3D)
2229
Gian Marcoae2af742018-02-15 12:35:44 +00002230 // Add offset for batched GEMM
2231 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002232
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002233#endif // defined(REINTERPRET_INPUT_AS_3D)
2234
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002235#if defined(MATRIX_B_DEPTH)
2236 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2237 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2238#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002239 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002240#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002241
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002242 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
2243
2244 VECTOR_TYPE acc0 = 0.0f;
2245#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2246 VECTOR_TYPE acc1 = 0.0f;
2247#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2248#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2249 VECTOR_TYPE acc2 = 0.0f;
2250#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2252 VECTOR_TYPE acc3 = 0.0f;
2253#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2254
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002255 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002256 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002257#if defined(REINTERPRET_INPUT_AS_3D)
2258 // Load values from matrix A
2259 VEC_DATA_TYPE(DATA_TYPE, 2)
2260 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2261#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2262 VEC_DATA_TYPE(DATA_TYPE, 2)
2263 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2264#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2265#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2266 VEC_DATA_TYPE(DATA_TYPE, 2)
2267 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2270 VEC_DATA_TYPE(DATA_TYPE, 2)
2271 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2272#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2273#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002274 // Load values from matrix A
2275 VEC_DATA_TYPE(DATA_TYPE, 2)
2276 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2277#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2278 VEC_DATA_TYPE(DATA_TYPE, 2)
2279 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2280#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2281#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2282 VEC_DATA_TYPE(DATA_TYPE, 2)
2283 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2286 VEC_DATA_TYPE(DATA_TYPE, 2)
2287 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2288#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002289#endif // defined(REINTERPRET_INPUT_AS_3D)
2290
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002291 // Load values from matrix B
2292 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
2293 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002294
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002295 // Accumulate
2296 acc0 += b0 * (VECTOR_TYPE)a0.s0;
2297 acc0 += b1 * (VECTOR_TYPE)a0.s1;
2298#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2299 acc1 += b0 * (VECTOR_TYPE)a1.s0;
2300 acc1 += b1 * (VECTOR_TYPE)a1.s1;
2301#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2302#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2303 acc2 += b0 * (VECTOR_TYPE)a2.s0;
2304 acc2 += b1 * (VECTOR_TYPE)a2.s1;
2305#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2307 acc3 += b0 * (VECTOR_TYPE)a3.s0;
2308 acc3 += b1 * (VECTOR_TYPE)a3.s1;
2309#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002310 }
2311
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002312 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002313 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002314#if defined(REINTERPRET_INPUT_AS_3D)
2315 // Load values from matrix A
2316 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2318 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2320#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2321 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2322#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2323#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2324 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2325#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2326#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002327 // Load values from matrix A
2328 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2329#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2330 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2331#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2332#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2333 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2334#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2335#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2336 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2337#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002338#endif // defined(REINTERPRET_INPUT_AS_3D)
2339
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002340 // Load values from matrix B
2341 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002342
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002343 // Accumulate
2344 acc0 += b0 * (VECTOR_TYPE)a0;
2345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2346 acc1 += b0 * (VECTOR_TYPE)a1;
2347#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2348#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2349 acc2 += b0 * (VECTOR_TYPE)a2;
2350#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2351#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2352 acc3 += b0 * (VECTOR_TYPE)a3;
2353#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002354 }
2355
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002356 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002357 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2358
Gian Marcoae2af742018-02-15 12:35:44 +00002359 // Compute dst address
2360 __global uchar *dst_addr = offset(&dst, 0, 0);
2361
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002362 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002363#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002364 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002365#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002366#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2367 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
2368#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2370 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
2371#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2373 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
2374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2375
2376 int z = get_global_id(2);
2377
2378#if defined(REINTERPRET_OUTPUT_AS_3D)
2379 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002380 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002381 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002382 // | |
2383 // | plane0 |
2384 // | |
2385 // |__________________|
2386 // |******************|
2387 // | cross_plane_pad |
2388 // |******************|
2389 // | |
2390 // | plane1 |
2391 // | |
2392 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002393
2394 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2395 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2396 zout = min(DEPTH_GEMM3D - 1, zout);
2397
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002398 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002399 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002400
2401 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2402 // multiply dst_stride_z by DEPTH_GEMM3D
2403 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2404
2405 // Store output block
2406 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
2407 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
2408#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2409 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
2410 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
2411#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2412#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2413 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
2414 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
2415#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2416#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2417 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
2418 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
2419#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2420
2421#else // defined(REINTERPRET_OUTPUT_AS_3D)
2422 // Add offset for batched GEMM
2423 dst_addr += z * dst_stride_z;
2424
2425 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002426 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00002427 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002428#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002429 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00002430 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002431#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2432#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002433 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00002434 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002435#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002437 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00002438 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002439#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002440#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002441}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002442#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002443
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works only with the 32-bit floating point data type (float) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       The body always computes 4 consecutive output columns per work-item (vload4/vstore4), so this kernel is meant to be used with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y selects how many output rows (1 to 4) each work-item computes.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha. It is converted to float before being applied,
 *       so an unsuffixed literal never turns the scaling into a double precision multiplication.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings of each input plane, in number of rows (scaled by src0_stride_y) (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings of each output plane, in number of rows (scaled by dst_stride_y) (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for matrix A: first of the NUM_ELEMS_PROCESSED_PER_THREAD_Y rows handled by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for matrix B: first column handled by this work-item
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The scalar bound is splatted explicitly so that min() matches the (gentype, gentype) overload
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings (padding rows -> bytes)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accRC holds output row R, column C of this work-item's tile
    float acc00 = 0.0f;
    float acc01 = 0.0f;
    float acc02 = 0.0f;
    float acc03 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
    float acc12 = 0.0f;
    float acc13 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
    float acc22 = 0.0f;
    float acc23 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
    float acc32 = 0.0f;
    float acc33 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load 4 consecutive values from each row of matrix A
        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load 4 consecutive values from each row of matrix A
        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load the 1st row of matrix B and move to the next one
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc01 = fma(a0.s0, b0.s1, acc01);
        acc02 = fma(a0.s0, b0.s2, acc02);
        acc03 = fma(a0.s0, b0.s3, acc03);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

        acc10 = fma(a1.s0, b0.s0, acc10);
        acc11 = fma(a1.s0, b0.s1, acc11);
        acc12 = fma(a1.s0, b0.s2, acc12);
        acc13 = fma(a1.s0, b0.s3, acc13);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

        acc20 = fma(a2.s0, b0.s0, acc20);
        acc21 = fma(a2.s0, b0.s1, acc21);
        acc22 = fma(a2.s0, b0.s2, acc22);
        acc23 = fma(a2.s0, b0.s3, acc23);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        acc30 = fma(a3.s0, b0.s0, acc30);
        acc31 = fma(a3.s0, b0.s1, acc31);
        acc32 = fma(a3.s0, b0.s2, acc32);
        acc33 = fma(a3.s0, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load the 2nd row of matrix B (A values are already in registers)
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s1, b0.s0, acc00);
        acc01 = fma(a0.s1, b0.s1, acc01);
        acc02 = fma(a0.s1, b0.s2, acc02);
        acc03 = fma(a0.s1, b0.s3, acc03);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

        acc10 = fma(a1.s1, b0.s0, acc10);
        acc11 = fma(a1.s1, b0.s1, acc11);
        acc12 = fma(a1.s1, b0.s2, acc12);
        acc13 = fma(a1.s1, b0.s3, acc13);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

        acc20 = fma(a2.s1, b0.s0, acc20);
        acc21 = fma(a2.s1, b0.s1, acc21);
        acc22 = fma(a2.s1, b0.s2, acc22);
        acc23 = fma(a2.s1, b0.s3, acc23);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        acc30 = fma(a3.s1, b0.s0, acc30);
        acc31 = fma(a3.s1, b0.s1, acc31);
        acc32 = fma(a3.s1, b0.s2, acc32);
        acc33 = fma(a3.s1, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load the 3rd row of matrix B (A values are already in registers)
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s2, b0.s0, acc00);
        acc01 = fma(a0.s2, b0.s1, acc01);
        acc02 = fma(a0.s2, b0.s2, acc02);
        acc03 = fma(a0.s2, b0.s3, acc03);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

        acc10 = fma(a1.s2, b0.s0, acc10);
        acc11 = fma(a1.s2, b0.s1, acc11);
        acc12 = fma(a1.s2, b0.s2, acc12);
        acc13 = fma(a1.s2, b0.s3, acc13);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

        acc20 = fma(a2.s2, b0.s0, acc20);
        acc21 = fma(a2.s2, b0.s1, acc21);
        acc22 = fma(a2.s2, b0.s2, acc22);
        acc23 = fma(a2.s2, b0.s3, acc23);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        acc30 = fma(a3.s2, b0.s0, acc30);
        acc31 = fma(a3.s2, b0.s1, acc31);
        acc32 = fma(a3.s2, b0.s2, acc32);
        acc33 = fma(a3.s2, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Load the 4th row of matrix B (A values are already in registers)
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s3, b0.s0, acc00);
        acc01 = fma(a0.s3, b0.s1, acc01);
        acc02 = fma(a0.s3, b0.s2, acc02);
        acc03 = fma(a0.s3, b0.s3, acc03);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

        acc10 = fma(a1.s3, b0.s0, acc10);
        acc11 = fma(a1.s3, b0.s1, acc11);
        acc12 = fma(a1.s3, b0.s2, acc12);
        acc13 = fma(a1.s3, b0.s3, acc13);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

        acc20 = fma(a2.s3, b0.s0, acc20);
        acc21 = fma(a2.s3, b0.s1, acc21);
        acc22 = fma(a2.s3, b0.s2, acc22);
        acc23 = fma(a2.s3, b0.s3, acc23);

#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        acc30 = fma(a3.s3, b0.s0, acc30);
        acc31 = fma(a3.s3, b0.s1, acc31);
        acc32 = fma(a3.s3, b0.s2, acc32);
        acc33 = fma(a3.s3, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(float);
    }

    // Leftover loop: handles the last COLS_A % 4 columns of A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
        acc02 = fma(a0, b0.s2, acc02);
        acc03 = fma(a0, b0.s3, acc03);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
        acc12 = fma(a1, b0.s2, acc12);
        acc13 = fma(a1, b0.s3, acc13);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
        acc22 = fma(a2, b0.s2, acc22);
        acc23 = fma(a2, b0.s3, acc23);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
        acc32 = fma(a3, b0.s2, acc32);
        acc33 = fma(a3, b0.s3, acc33);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Multiply by the weight of matrix-matrix product and store the result.
    // ALPHA is cast to float (as the generic kernel does with (VECTOR_TYPE)ALPHA) so that an
    // unsuffixed literal does not promote the scaling to double precision.
#if defined(ALPHA)
    acc00 = acc00 * (float)ALPHA;
    acc01 = acc01 * (float)ALPHA;
    acc02 = acc02 * (float)ALPHA;
    acc03 = acc03 * (float)ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc10 = acc10 * (float)ALPHA;
    acc11 = acc11 * (float)ALPHA;
    acc12 = acc12 * (float)ALPHA;
    acc13 = acc13 * (float)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc20 = acc20 * (float)ALPHA;
    acc21 = acc21 * (float)ALPHA;
    acc22 = acc22 * (float)ALPHA;
    acc23 = acc23 * (float)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc30 = acc30 * (float)ALPHA;
    acc31 = acc31 * (float)ALPHA;
    acc32 = acc32 * (float)ALPHA;
    acc33 = acc33 * (float)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The scalar bound is splatted explicitly so that min() matches the (gentype, gentype) overload
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings (padding rows -> bytes)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
2897
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) only and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       The loads from matrix B (vload2) and the stores to the destination (vstore2) are hard-coded to 2 columns, so this kernel
 *       expects -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2 (any other value would leave output columns uncomputed or compute them twice).
 *       At most 4 rows per work-item are computed: values of NUM_ELEMS_PROCESSED_PER_THREAD_Y greater than 4 are not handled.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped, since the plane index is
 *          obtained by dividing the row index by HEIGHT_GEMM3D
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes a block of NUM_ELEMS_PROCESSED_PER_THREAD_Y (<= 4) rows x 2 columns of the output.
    // The 2-column width is hard-coded in the vload2 of matrix B and in the vstore2 of the destination,
    // hence NUM_ELEMS_PROCESSED_PER_THREAD_X is expected to be 2.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Both operands are uint4 so the min(gentype, gentype) overload is selected unambiguously.
    zin = min(zin, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accRC holds output row R, column C of this work-item's block
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of A (8 rows of B) per iteration.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }
    // Left-over loop: process the remaining (COLS_A % 8) columns of A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Both operands are uint4 so the min(gentype, gentype) overload is selected unambiguously.
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3284
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01003285#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01003286/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
3287 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00003288 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
3289 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3290 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3291 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3292 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
3293 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3294 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
3295 *
3296 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3297 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
3298 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3299 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3300 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3301 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3302 *
3303 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3304 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3305 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3306 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3307 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3308 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3309 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3310 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3311 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3312 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3313 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3314 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
3315 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3316 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3317 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3318 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3319 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3320 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3321 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3322 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3323 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3324 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3325 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3326 */
3327__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
3328 IMAGE_DECLARATION(src1),
3329 IMAGE_DECLARATION(dst),
3330 uint src0_stride_z,
3331 uint src1_stride_z,
3332 uint dst_stride_z
3333#if defined(REINTERPRET_INPUT_AS_3D)
3334 ,
3335 uint src_cross_plane_pad
3336#endif // REINTERPRET_INPUT_AS_3D
3337#if defined(REINTERPRET_OUTPUT_AS_3D)
3338 ,
3339 uint dst_cross_plane_pad
3340#endif // REINTERPRET_OUTPUT_AS_3D
3341 )
3342{
3343 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3344
3345 // Compute starting address for matrix A and Matrix B
3346 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3347
3348 // Update address for the matrix A
3349 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3350
3351 // Update address for the matrix B
3352 src_addr.s1 += idx * sizeof(half);
3353
3354#if defined(REINTERPRET_INPUT_AS_3D)
3355 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3356 // in order to take into account the presence of possible cross plane paddings
3357 //
3358 // | |
3359 // | plane0 |
3360 // | |
3361 // |__________________|
3362 // |******************|
3363 // | cross_plane_pad |
3364 // |******************|
3365 // | |
3366 // | plane1 |
3367 // | |
3368 // |__________________|
3369
3370 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3371 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3372 zin = min(DEPTH_GEMM3D - 1, zin);
3373
3374 // Add offset due to the cross plane paddings
3375 zin *= (src_cross_plane_pad * src0_stride_y);
3376
3377 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3378 // multiply src0_stride_z by DEPTH_GEMM3D
3379 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3380
3381#else // defined(REINTERPRET_INPUT_AS_3D)
3382
3383 // Add offset for batched GEMM
3384 src_addr.s0 += get_global_id(2) * src0_stride_z;
3385
3386#endif // defined(REINTERPRET_INPUT_AS_3D)
3387
3388#if defined(MATRIX_B_DEPTH)
3389 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3390 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3391#else // defined(MATRIX_B_DEPTH)
3392 src_addr.s1 += get_global_id(2) * src1_stride_z;
3393#endif // defined(MATRIX_B_DEPTH)
3394
3395 float8 acc0 = 0.0h;
3396#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3397 float8 acc1 = 0.0h;
3398#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3399#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3400 float8 acc2 = 0.0h;
3401#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3402#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3403 float8 acc3 = 0.0h;
3404#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3405
3406 int i = 0;
3407 for(; i <= ((int)COLS_A - 4); i += 4)
3408 {
3409#if defined(REINTERPRET_INPUT_AS_3D)
3410 // Load values from matrix A
3411 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3412#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3413 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3414#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3415#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3416 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3417#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3418#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3419 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3420#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3421#else // defined(REINTERPRET_INPUT_AS_3D)
3422 // Load values from matrix A
3423 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3424#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3425 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3426#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3427#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3428 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3429#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3430#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3431 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3433#endif // defined(REINTERPRET_INPUT_AS_3D)
3434
3435 // Load values from matrix B
3436 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
3437 src_addr.s1 += src1_stride_y;
3438
3439 // Accumulate
3440 acc0 = fma(b0, (float8)a0.s0, acc0);
3441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3442 acc1 = fma(b0, (float8)a1.s0, acc1);
3443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3444#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3445 acc2 = fma(b0, (float8)a2.s0, acc2);
3446#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3447#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3448 acc3 = fma(b0, (float8)a3.s0, acc3);
3449#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3450
3451 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
3452 src_addr.s1 += src1_stride_y;
3453 acc0 = fma(b0, (float8)a0.s1, acc0);
3454#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3455 acc1 = fma(b0, (float8)a1.s1, acc1);
3456#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3457#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3458 acc2 = fma(b0, (float8)a2.s1, acc2);
3459#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3460#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3461 acc3 = fma(b0, (float8)a3.s1, acc3);
3462#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3463
3464 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
3465 src_addr.s1 += src1_stride_y;
3466 acc0 = fma(b0, (float8)a0.s2, acc0);
3467#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3468 acc1 = fma(b0, (float8)a1.s2, acc1);
3469#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3470#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3471 acc2 = fma(b0, (float8)a2.s2, acc2);
3472#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3473#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3474 acc3 = fma(b0, (float8)a3.s2, acc3);
3475#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3476
3477 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
3478 src_addr.s1 += src1_stride_y;
3479 acc0 = fma(b0, (float8)a0.s3, acc0);
3480#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3481 acc1 = fma(b0, (float8)a1.s3, acc1);
3482#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3483#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3484 acc2 = fma(b0, (float8)a2.s3, acc2);
3485#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3486#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3487 acc3 = fma(b0, (float8)a3.s3, acc3);
3488#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3489
3490 src_addr.s0 += 4 * sizeof(half);
3491 }
3492
3493 for(; i < (int)COLS_A; ++i)
3494 {
3495#if defined(REINTERPRET_INPUT_AS_3D)
3496 // Load values from matrix A
3497 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3498#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3499 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3500#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3501#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3502 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3503#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3504#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3505 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3506#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3507#else // defined(REINTERPRET_INPUT_AS_3D)
3508 // Load values from matrix A
3509 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3510#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3511 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3512#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3513#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3514 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3515#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3516#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3517 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3518#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3519#endif // defined(REINTERPRET_INPUT_AS_3D)
3520
3521 // Load values from matrix B
3522 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
3523
3524 src_addr += (int2)(sizeof(half), src1_stride_y);
3525
3526 // Accumulate
3527 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
3528#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3529 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
3530#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3531#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3532 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
3533#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3534#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3535 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
3536#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3537 }
3538
3539 // Multiply by the weight of matrix-matrix product and store the result
3540#if defined(ALPHA)
3541 half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
3542#else //defined(ALPHA)
3543 half8 hacc0 = convert_half8(acc0);
3544#endif // defined(ALPHA)
3545#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3546#if defined(ALPHA)
3547 half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
3548#else //defined(ALPHA)
3549 half8 hacc1 = convert_half8(acc1);
3550#endif //defined(ALPHA)
3551#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
3552
3553#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3554#if defined(ALPHA)
3555 half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
3556#else //defined(ALPHA)
3557 half8 hacc2 = convert_half8(acc2);
3558#endif //defined(ALPHA)
3559#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3560
3561#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3562#if defined(ALPHA)
3563 half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
3564#else //defined(ALPHA)
3565 half8 hacc3 = convert_half8(acc3);
3566#endif // defined(ALPHA)
3567#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3568
3569 int z = get_global_id(2);
3570
3571 // Compute destination address
3572 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3573
3574 // Compute dst address
3575 __global uchar *dst_addr = offset(&dst, 0, 0);
3576
3577#if defined(REINTERPRET_OUTPUT_AS_3D)
3578 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
3579 // in order to take into account the presence of possible cross plane paddings
3580 //
3581 // | |
3582 // | plane0 |
3583 // | |
3584 // |__________________|
3585 // |******************|
3586 // | cross_plane_pad |
3587 // |******************|
3588 // | |
3589 // | plane1 |
3590 // | |
3591 // |__________________|
3592
3593 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3594 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3595 zout = min(DEPTH_GEMM3D - 1, zout);
3596
3597 // Add offset due to the cross plane paddings
3598 zout *= (dst_cross_plane_pad * dst_stride_y);
3599
3600 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3601 // multiply dst_stride_z by DEPTH_GEMM3D
3602 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3603
3604 // Store the output block
3605 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3606#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3607 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3608#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3609#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3610 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3611#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3612#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3613 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
3614#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3615
3616#else // defined(REINTERPRET_OUTPUT_AS_3D)
3617 // Add offset for batched GEMM
3618 dst_addr += z * dst_stride_z;
3619
3620 // Store the output block
3621 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
3622#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3623 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
3624#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3625#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3626 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
3627#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3628#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3629 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
3630#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3631#endif // REINTERPRET_OUTPUT_AS_3D
3632}
3633
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 *       Both the products and the accumulation are performed in half precision (half8 accumulators).
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always reads and writes 8 consecutive columns (vload8/vstore8 on half8 values), so this kernel should be
 *       compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 to keep the matrix B column offset (idx) in step with the 8-wide output tile.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y is the number of output rows computed by each work-item; values from 1 to 4 are supported.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings of the input tensor, counted in rows (each row is src0_stride_y bytes) (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings of the output tensor, counted in rows (each row is dst_stride_y bytes) (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Rows beyond the last plane are clamped to the last plane so the padding offset never exceeds DEPTH_GEMM3D - 1 planes
    zin = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row of the tile
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop, unrolled by 4 along K (the columns of A): each step loads 4 values per row of A and
    // 4 consecutive rows of B, broadcasting one A element over an 8-wide row of B per fma
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: handles the remaining COLS_A % 4 columns one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one column and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate: accN += b0 * (half8)aN
        acc0 = fma(b0, (half8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Rows beyond the last plane are clamped to the last plane so the padding offset never exceeds DEPTH_GEMM3D - 1 planes
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01003965#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01003966
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003967#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003968
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003969#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003970/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
3971 *
Gian Marco19835e52018-01-30 13:35:54 +00003972 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003973 *
3974 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
3975 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
3976 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3977 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
3978 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003979 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
3980 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003981 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003982 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003983 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3984 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3985 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3986 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003987 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3988 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003989 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3990 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003991__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
3992 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003993{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003994 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003995 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
3996 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003997
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003998 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003999 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
4000
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004001 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004002 float4 c = vload4(0, (__global float *)src.ptr);
4003
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004004 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004005 float4 out = alpha_ab + (float4)BETA * c;
4006
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004007 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004008 vstore4(out, 0, (__global float *)dst.ptr);
4009}
4010
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004011#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004012/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
4013 *
Gian Marco19835e52018-01-30 13:35:54 +00004014 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004015 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004016 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
4017 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
4018 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4019 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
4020 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004021 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
4022 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004023 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004024 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004025 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4026 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4027 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4028 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004029 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
4030 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004031 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
4032 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004033__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
4034 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004035{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004036 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004037 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
4038 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004039
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004040 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004041 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
4042
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004043 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004044 half8 c = vload8(0, (__global half *)src.ptr);
4045
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004046 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004047 half8 out = alpha_ab + (half8)BETA * c;
4048
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004049 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004050 vstore8(out, 0, (__global half *)dst.ptr);
4051}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004052#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004053#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004054
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004055#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix product used by the locally connected layer: each row of A (src0) is multiplied by its own matrix B (src1).
 *
 * Work-item (x, y) reads row y of src0 and the Z-slice y of src1 (offset by src1_stride_z * y), and
 * produces 4 consecutive float results starting at column 4 * x. The dot products are accumulated two
 * elements of A at a time, followed by a scalar loop for any remaining element.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note Neither the input A nor the matrix B may be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int x_elem = get_global_id(0) * 4;
    const int row    = get_global_id(1);

    // Byte offsets: a_offset walks along row `row` of A, b_offset walks down the columns of B's slice `row`
    int a_offset = src0_offset_first_element_in_bytes + src0_stride_y * row;
    int b_offset = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_offset += x_elem * sizeof(float);

    const int a_end       = a_offset + (WIDTH_VECTOR_A * sizeof(float));
    const int a_end_pairs = a_end - 2 * (int)sizeof(float);

    float4 acc = (float4)(0.0f);

    // Main loop: two elements of A (and therefore two rows of B) per iteration
    while(a_offset <= a_end_pairs)
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_offset));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_offset));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_offset + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        a_offset += 2 * sizeof(float);
        b_offset += 2 * src1_stride_y;
    }

    // Tail: whatever is left of the row of A, one element at a time
    while(a_offset < a_end)
    {
        const float  a_val = *((__global float *)(src0_ptr + a_offset));
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + b_offset));

        acc += b_row * (float4)a_val;

        a_offset += sizeof(float);
        b_offset += src1_stride_y;
    }

    // Write the 4 results at this work-item's destination position
    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)out_img.ptr);
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004121#endif // defined(WIDTH_VECTOR_A)
4122
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE consecutive elements starting at its position in @p accum, adds the
 * VECTOR_SIZE consecutive elements starting at its position in @p biases, and stores the sum back into @p accum.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                             Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                        Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                          accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                        Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                          accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes   The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                            Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                       Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes  The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_addr  = (__global DATA_TYPE *)accum_img.ptr;
    __global DATA_TYPE *biases_addr = (__global DATA_TYPE *)biases_vec.ptr;

    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) acc_vals  = VLOAD(VECTOR_SIZE)(0, accum_addr);
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) bias_vals = VLOAD(VECTOR_SIZE)(0, biases_addr);

    // Overwrite the accumulate buffer with biases + accumulated values
    VSTORE(VECTOR_SIZE)
    (bias_vals + acc_vals, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)
4157#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)