blob: 9dd072bd6ea293c43d3b1323d57ea6c16be608f5 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
28
29/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
30 * the output matrix unrolling the values.
31 *
32 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
33 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
34 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
35 * @note Only the following values for M0, K0 and V0 are supported:
36 * M0: 2,3,4,5,6,7,8
37 * K0: 2,4,8,16
38 * V0: greater than 0
39 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
40 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
41 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
42 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
43 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
44 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
45 *
46 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
47 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
48 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
49 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
50 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
51 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
52 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
53 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
54 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
55 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
56 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
57 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
58 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
59 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
60 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
61 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
62 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
63 */
64__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
65 TENSOR3D_DECLARATION(dst)
66#if defined(REINTERPRET_INPUT_AS_3D)
67 ,
68 uint cross_plane_pad
69#endif // REINTERPRET_INPUT_AS_3D
70 )
71{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000072 // Block size
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000073#define BLOCK_SIZE ((M0) * (K0))
74
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000075 // Output offset X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000076#if defined(INTERLEAVE)
77#define OUTPUT_OFFSET_X (K0)
78#else // defined(INTERLEAVE)
79#define OUTPUT_OFFSET_X (BLOCK_SIZE)
80#endif // defined(INTERLEAVE)
81
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000082 // Output step X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000083#if defined(INTERLEAVE)
84#define OUTPUT_STEP_X (K0) * (V0)
85#else // Do not interleave
86#define OUTPUT_STEP_X (K0)
87#endif // defined(INTERLEAVE)
88
89 // Compute source and destination addresses
90 uint x = get_global_id(0);
91 uint y = get_global_id(1);
92 uint z = get_global_id(2);
93
94 // ------------------ Compute input/output addresses ---------------------------
95
96 // Compute the input address
97 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
98
99 // Compute the output address
100 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
101 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
102
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000103 REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0); //uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000104
105#if defined(REINTERPRET_INPUT_AS_3D)
106 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
107 // multiply src_stride_z by DEPTH_GEMM3D
108
109 // Note for the REINTERPRET_INPUT_AS_3D case
110 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
111 // in order to take into account the presence of possible cross plane paddings
112 //
113 // | |
114 // | plane0 |
115 // | |
116 // |__________________|
117 // |******************|
118 // | cross_plane_pad |
119 // |******************|
120 // | |
121 // | plane1 |
122 // | |
123 // |__________________|
124
125 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
126
127 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
128 zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
129 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
130 zin0 *= (cross_plane_pad * src_stride_y);
131#if M0 > 1
132 zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
133 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
134 zin1 *= (cross_plane_pad * src_stride_y);
135#endif // M0 > 1
136#if M0 > 2
137 zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
138 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
139 zin2 *= (cross_plane_pad * src_stride_y);
140#endif // M0 > 2
141#if M0 > 3
142 zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
143 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
144 zin3 *= (cross_plane_pad * src_stride_y);
145#endif // M0 > 3
146#if M0 > 4
147 zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
148 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
149 zin4 *= (cross_plane_pad * src_stride_y);
150#endif // M0 > 4
151#if M0 > 5
152 zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
153 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
154 zin5 *= (cross_plane_pad * src_stride_y);
155#endif // M0 > 5
156#if M0 > 6
157 zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
158 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
159 zin6 *= (cross_plane_pad * src_stride_y);
160#endif // M0 > 6
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000161#if M0 > 7
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000162 zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
163 zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
164 zin7 *= (cross_plane_pad * src_stride_y);
165#endif // M0 > 7
166
167#else // defined(REINTERPRET_INPUT_AS_3D)
168
169 input_ptr += z * (uint)src_stride_z;
170
171#endif // defined(REINTERPRET_INPUT_AS_3D)
172
173 // Add offset for batched GEMM
174 output_ptr += z * (uint)dst_stride_z;
175
176 // ---------------------------Load input values --------------------------------
177
178 // Load values from the LHS matrix
179 VEC_DATA_TYPE(DATA_TYPE, K0)
180 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
181#if M0 > 1
182 VEC_DATA_TYPE(DATA_TYPE, K0)
183 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
184#endif // M0 > 1
185#if M0 > 2
186 VEC_DATA_TYPE(DATA_TYPE, K0)
187 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
188#endif // M0 > 2
189#if M0 > 3
190 VEC_DATA_TYPE(DATA_TYPE, K0)
191 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
192#endif // M0 > 3
193#if M0 > 4
194 VEC_DATA_TYPE(DATA_TYPE, K0)
195 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
196#endif // M0 > 4
197#if M0 > 5
198 VEC_DATA_TYPE(DATA_TYPE, K0)
199 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
200#endif // M0 > 5
201#if M0 > 6
202 VEC_DATA_TYPE(DATA_TYPE, K0)
203 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
204#endif // M0 > 6
205#if M0 > 7
206 VEC_DATA_TYPE(DATA_TYPE, K0)
207 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
208#endif // M0 > 7
209
210 // ---------------------------Store output values ------------------------------
211
212 VSTORE(K0)
213 (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
214#if M0 > 1
215 VSTORE(K0)
216 (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
217#endif // M0 > 1
218#if M0 > 2
219 VSTORE(K0)
220 (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
221#endif // M0 > 2
222#if M0 > 3
223 VSTORE(K0)
224 (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
225#endif // M0 > 3
226#if M0 > 4
227 VSTORE(K0)
228 (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
229#endif // M0 > 4
230#if M0 > 5
231 VSTORE(K0)
232 (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
233#endif // M0 > 5
234#if M0 > 6
235 VSTORE(K0)
236 (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
237#endif // M0 > 6
238#if M0 > 7
239 VSTORE(K0)
240 (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
241#endif // M0 > 7
242
243#undef BLOCK_SIZE
244#undef OUTPUT_OFFSET_X
245#undef OUTPUT_STEP_X
246}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000247
248#if M0 == 2
249#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
250 ({ \
251 VEC_DATA_TYPE(DATA_TYPE, M0) \
252 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
253 VSTORE(M0) \
254 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
255 })
256#elif M0 == 3 // M0 == 3
257#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
258 ({ \
259 VEC_DATA_TYPE(DATA_TYPE, M0) \
260 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
261 VSTORE(M0) \
262 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
263 })
264#elif M0 == 4 // M0 == 4
265#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
266 ({ \
267 VEC_DATA_TYPE(DATA_TYPE, M0) \
268 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
269 VSTORE(M0) \
270 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
271 })
272#elif M0 == 5 // M0 == 5
273#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
274 ({ \
275 VEC_DATA_TYPE(DATA_TYPE, 4) \
276 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
277 DATA_TYPE res1 = a4.s##i; \
278 VSTORE(4) \
279 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
280 *((__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4) = res1; \
281 })
282#elif M0 == 6 // M0 == 6
283#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
284 ({ \
285 VEC_DATA_TYPE(DATA_TYPE, 4) \
286 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
287 VEC_DATA_TYPE(DATA_TYPE, 2) \
288 res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
289 VSTORE(4) \
290 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
291 VSTORE(2) \
292 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
293 })
294#elif M0 == 7 // M0 == 7
295#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
296 ({ \
297 VEC_DATA_TYPE(DATA_TYPE, 4) \
298 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
299 VEC_DATA_TYPE(DATA_TYPE, 3) \
300 res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
301 VSTORE(4) \
302 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
303 VSTORE(3) \
304 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
305 })
306#elif M0 == 8 // M0 == 8
307#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
308 ({ \
309 VEC_DATA_TYPE(DATA_TYPE, M0) \
310 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
311 VSTORE(M0) \
312 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
313 })
314#else // M0 not supported
315#error "M0 value not supported"
316#endif // N0 conditions
317
318/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
319 * the output matrix unrolling the values.
320 *
321 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
322 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
323 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
324 * @note Only the following values for M0, K0 and V0 are supported:
325 * M0: 2,3,4,5,6,7,8
326 * K0: 2,4,8,16
327 * V0: greater than 0
328 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
329 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
330 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
331 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
332 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
333 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
334 *
335 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
336 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
337 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
338 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
339 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
340 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
341 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
342 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
343 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
344 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
345 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
346 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
347 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
348 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
349 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
350 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
351 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
352 */
353__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
354 TENSOR3D_DECLARATION(dst)
355#if defined(REINTERPRET_INPUT_AS_3D)
356 ,
357 uint cross_plane_pad
358#endif // REINTERPRET_INPUT_AS_3D
359 )
360{
361 // Block size
362#define BLOCK_SIZE ((M0) * (K0))
363
364 // Output offset X
365#if defined(INTERLEAVE)
366#define OUTPUT_OFFSET_X (M0)
367#else // defined(INTERLEAVE)
368#define OUTPUT_OFFSET_X (BLOCK_SIZE)
369#endif // defined(INTERLEAVE)
370
371 // Output step X
372#if defined(INTERLEAVE)
373#define OUTPUT_STEP_X (M0) * (V0)
374#else // Do not interleave
375#define OUTPUT_STEP_X (M0)
376#endif // defined(INTERLEAVE)
377
378 // Compute source and destination addresses
379 uint x = get_global_id(0);
380 uint y = get_global_id(1);
381 uint z = get_global_id(2);
382
383 // ------------------ Compute input/output addresses ---------------------------
384
385 // Compute the input address
386 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
387
388 // Compute the output address
389 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
390 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
391
392 uint zin0 = 0;
393 uint zin1 = 0;
394 uint zin2 = 0;
395 uint zin3 = 0;
396 uint zin4 = 0;
397 uint zin5 = 0;
398 uint zin6 = 0;
399 uint zin7 = 0;
400
401#if defined(REINTERPRET_INPUT_AS_3D)
402 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
403 // multiply src_stride_z by DEPTH_GEMM3D
404
405 // Note for the REINTERPRET_INPUT_AS_3D case
406 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
407 // in order to take into account the presence of possible cross plane paddings
408 //
409 // | |
410 // | plane0 |
411 // | |
412 // |__________________|
413 // |******************|
414 // | cross_plane_pad |
415 // |******************|
416 // | |
417 // | plane1 |
418 // | |
419 // |__________________|
420
421 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
422
423 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
424 zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
425 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
426 zin0 *= (cross_plane_pad * src_stride_y);
427#if M0 > 1
428 zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
429 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
430 zin1 *= (cross_plane_pad * src_stride_y);
431#endif // M0 > 1
432#if M0 > 2
433 zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
434 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
435 zin2 *= (cross_plane_pad * src_stride_y);
436#endif // M0 > 2
437#if M0 > 3
438 zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
439 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
440 zin3 *= (cross_plane_pad * src_stride_y);
441#endif // M0 > 3
442#if M0 > 4
443 zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
444 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
445 zin4 *= (cross_plane_pad * src_stride_y);
446#endif // M0 > 4
447#if M0 > 5
448 zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
449 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
450 zin5 *= (cross_plane_pad * src_stride_y);
451#endif // M0 > 5
452#if M0 > 6
453 zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
454 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
455 zin6 *= (cross_plane_pad * src_stride_y);
456#endif // M0 > 6
457#if M0 > 6
458 zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
459 zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
460 zin7 *= (cross_plane_pad * src_stride_y);
461#endif // M0 > 7
462
463#else // defined(REINTERPRET_INPUT_AS_3D)
464
465 input_ptr += z * (uint)src_stride_z;
466
467#endif // defined(REINTERPRET_INPUT_AS_3D)
468
469 // Add offset for batched GEMM
470 output_ptr += z * (uint)dst_stride_z;
471
472 // ---------------------------Load input values --------------------------------
473
474 // Load values from the LHS matrix
475 VEC_DATA_TYPE(DATA_TYPE, K0)
476 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
477#if M0 > 1
478 VEC_DATA_TYPE(DATA_TYPE, K0)
479 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
480#endif // M0 > 1
481#if M0 > 2
482 VEC_DATA_TYPE(DATA_TYPE, K0)
483 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
484#endif // M0 > 2
485#if M0 > 3
486 VEC_DATA_TYPE(DATA_TYPE, K0)
487 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
488#endif // M0 > 3
489#if M0 > 4
490 VEC_DATA_TYPE(DATA_TYPE, K0)
491 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
492#endif // M0 > 4
493#if M0 > 5
494 VEC_DATA_TYPE(DATA_TYPE, K0)
495 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
496#endif // M0 > 5
497#if M0 > 6
498 VEC_DATA_TYPE(DATA_TYPE, K0)
499 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
500#endif // M0 > 6
501#if M0 > 7
502 VEC_DATA_TYPE(DATA_TYPE, K0)
503 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
504#endif // M0 > 7
505
506 // ---------------------------Transpose and store block -----------------------
507
508 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
509 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
510#if K0 > 2
511 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
512 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
513#endif // K0 > 2
514#if K0 > 4
515 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
516 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
517 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
518 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
519#endif // K0 > 4
520#if K0 > 8
521 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
522 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
523 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
524 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
525 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
526 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
527 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
528 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
529#endif // K0 > 8
530
531#undef BLOCK_SIZE
532#undef OUTPUT_OFFSET_X
533#undef OUTPUT_STEP_X
534}
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000535#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
536
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000537#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
538/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
539 * the output matrix unrolling the values.
540 *
541 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
542 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
543 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
544 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
545 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
546 * @note Only the following values for K0, N0 and H0 are supported:
547 * N0: 2,4,8,16
548 * K0: 1,2,4,8,16
549 * H0: greater than 0
550 *
551 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
552 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
553 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
554 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
555 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
556 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
557 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
558 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
559 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
560 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
561 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
562 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
563 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
564 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
565 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
566 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
567 */
568__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
569 TENSOR3D_DECLARATION(dst))
570{
571 // Block size
572#define BLOCK_SIZE ((K0) * (N0))
573
574 // Output offset X
575#if defined(INTERLEAVE)
576#define OUTPUT_OFFSET_X (N0)
577#else // defined(INTERLEAVE)
578#define OUTPUT_OFFSET_X (BLOCK_SIZE)
579#endif // defined(INTERLEAVE)
580
581 // Output step X
582#if defined(INTERLEAVE)
583#define OUTPUT_STEP_X (N0) * (H0)
584#else // Do not interleave
585#define OUTPUT_STEP_X (N0)
586#endif // defined(INTERLEAVE)
587
588 // Compute source and destination addresses
589 uint x = get_global_id(0);
590 uint y = get_global_id(1);
591 uint z = get_global_id(2);
592
593 // ------------------ Compute input/output addresses ---------------------------
594
595 // Compute the input address
596 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
597
598 // Compute the output address
599 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
600 x / (uint)H0)
601 * (uint)dst_stride_y)
602 + z * (uint)dst_stride_z;
603
604 // ---------------------------Load input values --------------------------------
605
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000606 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); ////uint a0=0, a1=0, a2=0...a(M0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000607
608 // Load values from the RHS matrix
609 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
610#if K0 > 1
611 if(y * (uint)K0 + 1 < SRC_HEIGHT)
612 {
613 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
614 }
615#endif // K0 > 1
616#if K0 > 2
617 if(y * (uint)K0 + 2 < SRC_HEIGHT)
618 {
619 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
620 }
621 if(y * (uint)K0 + 3 < SRC_HEIGHT)
622 {
623 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
624 }
625#endif // K0 > 2
626#if K0 > 4
627 if(y * (uint)K0 + 4 < SRC_HEIGHT)
628 {
629 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
630 }
631 if(y * (uint)K0 + 5 < SRC_HEIGHT)
632 {
633 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
634 }
635 if(y * (uint)K0 + 6 < SRC_HEIGHT)
636 {
637 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
638 }
639 if(y * (uint)K0 + 7 < SRC_HEIGHT)
640 {
641 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
642 }
643#endif // K0 > 4
644#if K0 > 8
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000645 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000646 {
647 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
648 }
649 if(y * (uint)K0 + 9 < SRC_HEIGHT)
650 {
651 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
652 }
653 if(y * (uint)K0 + 10 < SRC_HEIGHT)
654 {
655 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
656 }
657 if(y * (uint)K0 + 11 < SRC_HEIGHT)
658 {
659 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
660 }
661 if(y * (uint)K0 + 12 < SRC_HEIGHT)
662 {
663 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
664 }
665 if(y * (uint)K0 + 13 < SRC_HEIGHT)
666 {
667 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
668 }
669 if(y * (uint)K0 + 14 < SRC_HEIGHT)
670 {
671 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
672 }
673 if(y * (uint)K0 + 15 < SRC_HEIGHT)
674 {
675 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
676 }
677#endif // K0 > 8
678
679 // ---------------------------Store output values ------------------------------
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000680 VSTORE(N0)
681 (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
682#if K0 > 1
683 VSTORE(N0)
684 (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
685#endif // K0 > 1
686#if K0 > 2
687 VSTORE(N0)
688 (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
689 VSTORE(N0)
690 (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
691#endif // K0 > 2
692#if K0 > 4
693 VSTORE(N0)
694 (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
695 VSTORE(N0)
696 (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
697 VSTORE(N0)
698 (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
699 VSTORE(N0)
700 (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
701#endif // N0 > 4
702#if K0 > 8
703 VSTORE(N0)
704 (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
705 VSTORE(N0)
706 (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
707 VSTORE(N0)
708 (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
709 VSTORE(N0)
710 (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
711 VSTORE(N0)
712 (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
713 VSTORE(N0)
714 (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
715 VSTORE(N0)
716 (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
717 VSTORE(N0)
718 (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
719#endif // N0 > 8
720
721#undef BLOCK_SIZE
722#undef OUTPUT_OFFSET_X
723#undef OUTPUT_STEP_X
724}
725
#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=4, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                    N0: 2,4,8,16
 *                                    K0: 4,8,16
 *                                    H0: greater than 0
 * @note Row 0 of each block is always read. Rows 1..K0-1 whose source row index is >= SRC_HEIGHT are not read
 *       and are written to the output as zeros.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: distance (in elements) between the H0 blocks that share an output row.
    // OUTPUT_STEP_X:   distance (in elements) between consecutive K0-wide rows of one transposed block.
    // With INTERLEAVE the rows of the H0 blocks alternate; otherwise each block is stored contiguously.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the block along the source X dimension, y the block along the
    // source Y dimension, z the batch/plane
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: top-left element of the K0xN0 block
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive x-blocks share one output row (x / H0), placed at
    // slot (x % H0) within that row; each y-block advances by H0 whole blocks
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix. Row 0 is read unconditionally; every other row is read only when it
    // lies inside the source (row index < SRC_HEIGHT) and otherwise keeps the zero from the initialization above.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

    // resJ gathers column J of the loaded block (component .sJ of every row aI)
#if K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 2
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // Each transposed row resJ (K0 elements) is written OUTPUT_STEP_X elements after the previous one
    VSTORE(K0)(res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
    VSTORE(K0)(res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 4
    VSTORE(K0)(res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
    VSTORE(K0)(res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
1019#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
1020
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001021#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
1022
/** ARM_DOT(x, y, val): accumulates the dot product of the first four components (.s0 .. .s3) of @p x and @p y
 * into the scalar lvalue @p val, one fma() per component in the order s0, s1, s2, s3.
 *
 * Every argument is parenthesized so that any expression (e.g. a swizzle or a conditional) can be passed.
 * Note: @p x and @p y are evaluated four times and @p val eight times, so arguments must be side-effect free.
 */
#define ARM_DOT(x, y, val)                  \
    ({                                      \
        (val) = fma((x).s0, (y).s0, (val)); \
        (val) = fma((x).s1, (y).s1, (val)); \
        (val) = fma((x).s2, (y).s2, (val)); \
        (val) = fma((x).s3, (y).s3, (val)); \
    })
1030
1031#if K0 == 4
1032#define ARM_DOT_K0(a, b, c) \
1033 ({ \
1034 ARM_DOT(a, b, c); \
1035 })
1036#elif K0 == 8 // K0 == 8
1037#define ARM_DOT_K0(a, b, c) \
1038 ({ \
1039 ARM_DOT((a).s0123, (b).s0123, c); \
1040 ARM_DOT((a).s4567, (b).s4567, c); \
1041 })
1042#elif K0 == 16 // K0 == 16
1043#define ARM_DOT_K0(a, b, c) \
1044 ({ \
1045 ARM_DOT((a).s0123, (b).s0123, c); \
1046 ARM_DOT((a).s4567, (b).s4567, c); \
1047 ARM_DOT((a).s89AB, (b).s89AB, c); \
1048 ARM_DOT((a).sCDEF, (b).sCDEF, c); \
1049 })
1050#else // K0 not supported
1051#error "K0 value not supported"
1052#endif // K0 conditions
1053
1054#if N0 == 2
1055#define ARM_DOT_K0XN0(a, b, c) \
1056 ({ \
1057 ARM_DOT_K0((a), (b##0), (c.s0)); \
1058 ARM_DOT_K0((a), (b##1), (c.s1)); \
1059 })
1060#elif N0 == 4 // N0 == 4
1061#define ARM_DOT_K0XN0(a, b, c) \
1062 ({ \
1063 ARM_DOT_K0((a), (b##0), (c.s0)); \
1064 ARM_DOT_K0((a), (b##1), (c.s1)); \
1065 ARM_DOT_K0((a), (b##2), (c.s2)); \
1066 ARM_DOT_K0((a), (b##3), (c.s3)); \
1067 })
1068#elif N0 == 8 // N0 == 8
1069#define ARM_DOT_K0XN0(a, b, c) \
1070 ({ \
1071 ARM_DOT_K0((a), (b##0), (c.s0)); \
1072 ARM_DOT_K0((a), (b##1), (c.s1)); \
1073 ARM_DOT_K0((a), (b##2), (c.s2)); \
1074 ARM_DOT_K0((a), (b##3), (c.s3)); \
1075 ARM_DOT_K0((a), (b##4), (c.s4)); \
1076 ARM_DOT_K0((a), (b##5), (c.s5)); \
1077 ARM_DOT_K0((a), (b##6), (c.s6)); \
1078 ARM_DOT_K0((a), (b##7), (c.s7)); \
1079 })
1080#elif N0 == 16 // N0 == 16
1081#define ARM_DOT_K0XN0(a, b, c) \
1082 ({ \
1083 ARM_DOT_K0((a), (b##0), (c.s0)); \
1084 ARM_DOT_K0((a), (b##1), (c.s1)); \
1085 ARM_DOT_K0((a), (b##2), (c.s2)); \
1086 ARM_DOT_K0((a), (b##3), (c.s3)); \
1087 ARM_DOT_K0((a), (b##4), (c.s4)); \
1088 ARM_DOT_K0((a), (b##5), (c.s5)); \
1089 ARM_DOT_K0((a), (b##6), (c.s6)); \
1090 ARM_DOT_K0((a), (b##7), (c.s7)); \
1091 ARM_DOT_K0((a), (b##8), (c.s8)); \
1092 ARM_DOT_K0((a), (b##9), (c.s9)); \
1093 ARM_DOT_K0((a), (b##A), (c.sA)); \
1094 ARM_DOT_K0((a), (b##B), (c.sB)); \
1095 ARM_DOT_K0((a), (b##C), (c.sC)); \
1096 ARM_DOT_K0((a), (b##D), (c.sD)); \
1097 ARM_DOT_K0((a), (b##E), (c.sE)); \
1098 ARM_DOT_K0((a), (b##F), (c.sF)); \
1099 })
1100#else // N0 not supported
1101#error "N0 value not supported"
1102#endif // N0 conditions
1103
1104/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1105 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1106 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1107 *
1108 * @note The number of columns in the RHS matrix NOT reshaped needs to be passed at compile time using -DK (i.e. -Dk=128).
1109 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1110 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1111 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1112 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1113 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1114 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1115 * - M0 = 2, 3, 4, 5, 6, 7, 8
1116 * - N0 = 2, 4, 8, 16
1117 * - K0 = 4, 8, 16
1118 *
1119 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1120 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1121 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1122 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1123 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1124 *
1125 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1126 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1127 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1128 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1129 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1130 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001131 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001132 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1133 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1134 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1135 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1136 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001137 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001138 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1139 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1140 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1141 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1142 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1143 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1144 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1145 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1146 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1147 */
1148__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1149 IMAGE_DECLARATION(rhs),
1150 IMAGE_DECLARATION(dst),
1151 uint lhs_stride_z,
1152 uint rhs_stride_z,
1153 uint dst_stride_z
1154#if defined(REINTERPRET_OUTPUT_AS_3D)
1155 ,
1156 uint dst_cross_plane_pad
1157#endif // REINTERPRET_OUTPUT_AS_3D
1158 )
1159{
1160 // Block size
1161#define LHS_BLOCK_SIZE ((K0) * (M0))
1162
1163#if defined(LHS_INTERLEAVE)
1164#define LHS_OFFSET_X (K0)
1165#define LHS_STEP_X ((K0) * (V0))
1166#define LHS_STEP_LOOP (1)
1167#else // defined(INTERLEAVE)
1168#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1169#define LHS_STEP_X (K0)
1170#define LHS_STEP_LOOP (V0)
1171#endif // defined(INTERLEAVE)
1172
1173 // Block size
1174#define RHS_BLOCK_SIZE ((K0) * (N0))
1175
1176 // RHS offset and step X
1177#if defined(RHS_INTERLEAVE)
1178#define RHS_OFFSET_X (K0)
1179#define RHS_STEP_X ((K0) * (H0))
1180#define RHS_STEP_LOOP (1)
1181#else // defined(RHS_INTERLEAVE)
1182#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1183#define RHS_STEP_X (K0)
1184#define RHS_STEP_LOOP (H0)
1185#endif // defined(RHS_INTERLEAVE)
1186
1187 // Compute LHS matrix address
1188 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1189 (get_global_id(2) * lhs_stride_z);
1190
1191 // Compute RHS matrix address
1192 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1193
1194#if defined(MATRIX_B_DEPTH)
1195 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1196 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1197#else // defined(MATRIX_B_DEPTH)
1198 rhs_addr += get_global_id(2) * rhs_stride_z;
1199#endif // defined(MATRIX_B_DEPTH)
1200
1201 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001202 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001203
1204 for(int i = 0; i < K; i += K0)
1205 {
1206 // Supported cases (M0, K0):
1207 // 2,4 - 2,8 - 2,16
1208 // 3,4 - 3,8 - 3,16
1209 // 4,4 - 4,8 - 4,16
1210 // 5,4 - 5,8 - 5,16
1211 // 6,4 - 6,8 - 6,16
1212 // Load values from LHS matrix
1213 VEC_DATA_TYPE(DATA_TYPE, K0)
1214 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE)));
1215#if M0 > 1
1216 VEC_DATA_TYPE(DATA_TYPE, K0)
1217 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE)));
1218#endif // M0 > 1
1219#if M0 > 2
1220 VEC_DATA_TYPE(DATA_TYPE, K0)
1221 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE)));
1222#endif // M0 > 2
1223#if M0 > 3
1224 VEC_DATA_TYPE(DATA_TYPE, K0)
1225 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE)));
1226#endif // M0 > 3
1227#if M0 > 4
1228 VEC_DATA_TYPE(DATA_TYPE, K0)
1229 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE)));
1230#endif // M0 > 4
1231#if M0 > 5
1232 VEC_DATA_TYPE(DATA_TYPE, K0)
1233 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE)));
1234#endif // M0 > 5
1235#if M0 > 6
1236 VEC_DATA_TYPE(DATA_TYPE, K0)
1237 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE)));
1238#endif // M0 > 6
1239#if M0 > 7
1240 VEC_DATA_TYPE(DATA_TYPE, K0)
1241 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE)));
1242#endif // M0 > 7
1243
1244 // Load values from RHS matrix
1245 VEC_DATA_TYPE(DATA_TYPE, K0)
1246 b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1247 VEC_DATA_TYPE(DATA_TYPE, K0)
1248 b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
1249#if N0 > 2
1250 VEC_DATA_TYPE(DATA_TYPE, K0)
1251 b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
1252 VEC_DATA_TYPE(DATA_TYPE, K0)
1253 b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
1254#endif // N0 > 2
1255#if N0 > 4
1256 VEC_DATA_TYPE(DATA_TYPE, K0)
1257 b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
1258 VEC_DATA_TYPE(DATA_TYPE, K0)
1259 b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
1260 VEC_DATA_TYPE(DATA_TYPE, K0)
1261 b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
1262 VEC_DATA_TYPE(DATA_TYPE, K0)
1263 b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
1264#endif // N0 > 4
1265#if N0 > 8
1266 VEC_DATA_TYPE(DATA_TYPE, K0)
1267 b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
1268 VEC_DATA_TYPE(DATA_TYPE, K0)
1269 b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
1270 VEC_DATA_TYPE(DATA_TYPE, K0)
1271 bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
1272 VEC_DATA_TYPE(DATA_TYPE, K0)
1273 bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
1274 VEC_DATA_TYPE(DATA_TYPE, K0)
1275 bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
1276 VEC_DATA_TYPE(DATA_TYPE, K0)
1277 bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
1278 VEC_DATA_TYPE(DATA_TYPE, K0)
1279 bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
1280 VEC_DATA_TYPE(DATA_TYPE, K0)
1281 bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
1282#endif // N0 > 8
1283
1284 // Accumulate
1285 ARM_DOT_K0XN0(a0, b, c0);
1286#if M0 > 1
1287 ARM_DOT_K0XN0(a1, b, c1);
1288#endif // M0 > 1
1289#if M0 > 2
1290 ARM_DOT_K0XN0(a2, b, c2);
1291#endif // M0 > 2
1292#if M0 > 3
1293 ARM_DOT_K0XN0(a3, b, c3);
1294#endif // M0 > 3
1295#if M0 > 4
1296 ARM_DOT_K0XN0(a4, b, c4);
1297#endif // M0 > 4
1298#if M0 > 5
1299 ARM_DOT_K0XN0(a5, b, c5);
1300#endif // M0 > 5
1301#if M0 > 6
1302 ARM_DOT_K0XN0(a6, b, c6);
1303#endif // M0 > 6
1304#if M0 > 7
1305 ARM_DOT_K0XN0(a7, b, c7);
1306#endif // M0 > 7
1307
1308 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1309 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1310 }
1311
1312 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1313
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001314 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001315
1316#if defined(REINTERPRET_OUTPUT_AS_3D)
1317 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1318 // in order to take into account the presence of possible cross plane paddings
1319 //
1320 // | |
1321 // | plane0 |
1322 // | |
1323 // |__________________|
1324 // |******************|
1325 // | cross_plane_pad |
1326 // |******************|
1327 // | |
1328 // | plane1 |
1329 // | |
1330 // |__________________|
1331
1332 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1333 zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1334 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001335 zout0 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001336#if M0 > 1
1337 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1338 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001339 zout1 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001340#endif // M0 > 1
1341#if M0 > 2
1342 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1343 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001344 zout2 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001345#endif // M0 > 2
1346#if M0 > 3
1347 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1348 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001349 zout3 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001350#endif // M0 > 3
1351#if M0 > 4
1352 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1353 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001354 zout4 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001355#endif // M0 > 4
1356#if M0 > 5
1357 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1358 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001359 zout5 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001360#endif // M0 > 5
1361#if M0 > 6
1362 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1363 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001364 zout6 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001365#endif // M0 > 6
1366#if M0 > 6
1367 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1368 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00001369 zout7 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001370#endif // M0 > 7
1371
1372 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1373 // multiply dst_stride_z by DEPTH_GEMM3D
1374 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1375
1376#else // defined(REINTERPRET_OUTPUT_AS_3D)
1377
1378 // Add offset for batched GEMM
1379 dst_addr += get_global_id(2) * dst_stride_z;
1380
1381#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1382
1383 // Multiply by the weight of matrix-matrix product and store the result
1384#if defined(ALPHA)
1385 c0 = c0 * (DATA_TYPE)ALPHA;
1386#if M0 > 1
1387 c1 = c1 * (DATA_TYPE)ALPHA;
1388#endif // M0 > 1
1389#if M0 > 2
1390 c2 = c2 * (DATA_TYPE)ALPHA;
1391#endif // M0 > 2
1392#if M0 > 3
1393 c3 = c3 * (DATA_TYPE)ALPHA;
1394#endif // M0 > 3
1395#if M0 > 4
1396 c4 = c4 * (DATA_TYPE)ALPHA;
1397#endif // M0 > 4
1398#if M0 > 5
1399 c5 = c5 * (DATA_TYPE)ALPHA;
1400#endif // M0 > 5
1401#if M0 > 6
1402 c6 = c6 * (DATA_TYPE)ALPHA;
1403#endif // M0 > 5
1404#if M0 > 7
1405 c7 = c7 * (DATA_TYPE)ALPHA;
1406#endif // M0 > 7
1407#endif // defined(ALPHA)
1408
1409 // Store output block
1410 VSTORE(N0)
1411 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
1412#if M0 > 1
1413 VSTORE(N0)
1414 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
1415#endif // M0 > 1
1416#if M0 > 2
1417 VSTORE(N0)
1418 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
1419#endif // M0 > 2
1420#if M0 > 3
1421 VSTORE(N0)
1422 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
1423#endif // M0 > 3
1424#if M0 > 4
1425 VSTORE(N0)
1426 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
1427#endif // M0 > 4
1428#if M0 > 5
1429 VSTORE(N0)
1430 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
1431#endif // M0 > 5
1432#if M0 > 6
1433 VSTORE(N0)
1434 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
1435#endif // M0 > 6
1436#if M0 > 7
1437 VSTORE(N0)
1438 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
1439#endif // M0 > 7
1440
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001441
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001442#undef LHS_BLOCK_SIZE
1443#undef LHS_OFFSET_X
1444#undef LHS_STEP_X
1445#undef RHS_BLOCK_SIZE
1446#undef RHS_OFFSET_X
1447#undef RHS_STEP_X
1448}
1449#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
1450
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// The kernel below only moves data (vload/vstore), so it works on an unsigned integer
// type whose width in bytes matches ELEMENT_SIZE.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE == 1
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item copies TRANSPOSE_W consecutive elements read at its source work-item address into the destination,
 * with the roles of X and Y swapped. The chunk with index x of source row y is written to destination row
 * (x / MULT_TRANSPOSE1XW_WIDTH), starting at element (y * MULT_TRANSPOSE1XW_WIDTH + x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W.
 * Hence MULT_TRANSPOSE1XW_WIDTH consecutive chunks of the same source row are stored side by side.
 * NOTE(review): the source address comes from CONVERT_TO_TENSOR3D_STRUCT, i.e. from src_step_x/src_step_y;
 * the mapping above assumes the host sets src_step_x to TRANSPOSE_W elements.
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The element size in bytes must be passed at compile time using -DELEMENT_SIZE (1, 2 or 4)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Source address of the chunk handled by this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Destination address: X and Y are swapped
    const uint chunk_bytes = TRANSPOSE_W * sizeof(DATA_TYPE);
    const uint dst_row     = gid_x / MULT_TRANSPOSE1XW_WIDTH;
    const uint dst_slot    = gid_x % MULT_TRANSPOSE1XW_WIDTH;

    const uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                                   + dst_row * dst_stride_y
                                   + (gid_y * MULT_TRANSPOSE1XW_WIDTH + dst_slot) * chunk_bytes
                                   + gid_z * dst_stride_z; // batch offset

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    chunk = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (chunk, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001509
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 * will be simply unrolled.
 *
 * Each work-item loads four rows of four elements (a 4x4 tile) and writes them to the destination as four groups
 * of four elements, spaced 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart. Without UNROLL_BLOCK, group i holds column i
 * of the tile. With UNROLL_BLOCK, group i holds row i unchanged.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Number of padding rows at the bottom of each plane; each plane crossed adds
 *                                               cross_plane_pad * src_stride_y bytes (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Destination address (X and Y are swapped). Consecutive gid_x are 16 * MULT_INTERLEAVE4X4_HEIGHT elements apart
    // along a destination row; MULT_INTERLEAVE4X4_HEIGHT consecutive gid_y share one destination row, 4 elements apart.
    const uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                                   + gid_x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT
                                   + (gid_y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y
                                   + (gid_y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE)
                                   + gid_z * dst_stride_z; // batch offset

#if defined(REINTERPRET_INPUT_AS_3D)
    // The tile starts at element (gid_x * 4, gid_y * 4) of the 2D view of the input
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + gid_x * 4 * sizeof(DATA_TYPE) + gid_y * 4 * src_stride_y;

    // The 2D tile is read from a 3D tensor, so each of its 4 rows may belong to a different plane and
    // the cross plane paddings between planes must be skipped:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane index of each row: M (gid_y * 4 + row) divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(gid_y * 4)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(zin, (uint)(DEPTH_GEMM3D - 1));

    // Byte offset introduced by the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Batches live in the fourth dimension, hence src_stride_z is scaled by DEPTH_GEMM3D
    input_ptr += gid_z * src_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Regular tensor: the work-item address comes from CONVERT_TO_TENSOR3D_STRUCT and there is no padding to skip
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;
    const uint4     zin       = (uint4)0;
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Load the four rows of the 4x4 tile from matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));

    __global DATA_TYPE *dst_block = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

#if defined(UNROLL_BLOCK)
    // Rows are copied unchanged
    vstore4(a0, 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a1, 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a2, 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a3, 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#else  // defined(UNROLL_BLOCK)
    // Column i of the tile becomes group i in the destination
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0), 0, dst_block + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1), 0, dst_block + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2), 0, dst_block + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4((VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3), 0, dst_block + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001637
Gian Marco36a0a462018-01-12 10:21:40 +00001638#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * Each work-item accumulates a 4x4 block of the output in four float4 accumulators, walking one interleaved row of A
 * and one transposed row of B four elements at a time. The result is optionally scaled by ALPHA and then stored as four rows.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Number of padding rows at the bottom of each output plane; each plane crossed adds
 *                                                cross_plane_pad * dst_stride_y bytes (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Row of the transposed B / interleaved A that this work-item reads, and the batch index
    const int block_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int block_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch   = get_global_id(2);

    // Element offset of this work-item's 4-wide lane inside those rows
    const int lane_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int lane_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // Byte offsets of the rows of A and B
    int a_row_in_bytes = src0_offset_first_element_in_bytes + block_y * src0_stride_y + batch * src0_stride_z;
    int b_row_in_bytes = src1_offset_first_element_in_bytes + block_x * src1_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while A has more: wrap the batch index instead of sliding past the end of B
    b_row_in_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_row_in_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *a_ptr = (__global float *)(src0_ptr + a_row_in_bytes);
    __global float *b_ptr = (__global float *)(src1_ptr + b_row_in_bytes);

    // End of the B row, taken before the lane offset is applied
    __global float *b_end = b_ptr + COLS_B;

    a_ptr += lane_a;
    b_ptr += lane_b;

    // Element distance between two consecutive 4-wide chunks of the same row of A / B
    const int step_a = 4 * MULT_INTERLEAVE4X4_HEIGHT;
    const int step_b = 4 * MULT_TRANSPOSE1XW_WIDTH;

    // One accumulator per output row of the 4x4 block
    float4 acc0 = 0.0f;
    float4 acc1 = 0.0f;
    float4 acc2 = 0.0f;
    float4 acc3 = 0.0f;

    // Main loop: two chunks per iteration
    for(; b_ptr <= (b_end - 2 * step_b); a_ptr += 2 * step_a, b_ptr += 2 * step_b)
    {
        float4 a_vals = vload4(0, a_ptr);
        float4 b_vals = vload4(0, b_ptr);

        acc0 += (float4)a_vals.s0 * b_vals;
        acc1 += (float4)a_vals.s1 * b_vals;
        acc2 += (float4)a_vals.s2 * b_vals;
        acc3 += (float4)a_vals.s3 * b_vals;

        a_vals = vload4(0, a_ptr + step_a);
        b_vals = vload4(0, b_ptr + step_b);

        acc0 += (float4)a_vals.s0 * b_vals;
        acc1 += (float4)a_vals.s1 * b_vals;
        acc2 += (float4)a_vals.s2 * b_vals;
        acc3 += (float4)a_vals.s3 * b_vals;
    }

    // Left-over chunks, one per iteration
    for(; b_ptr < b_end; a_ptr += step_a, b_ptr += step_b)
    {
        const float4 a_vals = vload4(0, a_ptr);
        const float4 b_vals = vload4(0, b_ptr);

        acc0 += (float4)a_vals.s0 * b_vals;
        acc1 += (float4)a_vals.s1 * b_vals;
        acc2 += (float4)a_vals.s2 * b_vals;
        acc3 += (float4)a_vals.s3 * b_vals;
    }

#if defined(ALPHA)
    // Weight of the matrix product
    acc0 = acc0 * (float4)ALPHA;
    acc1 = acc1 * (float4)ALPHA;
    acc2 = acc2 * (float4)ALPHA;
    acc3 = acc3 * (float4)ALPHA;
#endif // defined(ALPHA)

    // Destination address of this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored in a 3D tensor, so each of its 4 rows may belong to a different plane and
    // the cross plane paddings between planes must be skipped:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane index of each row: M (get_global_id(1) * 4 + row) divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(zout, (uint)(DEPTH_GEMM3D - 1));

    // Byte offset introduced by the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is scaled by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Plain 2D output: no padding between rows
    const uint4 zout = (uint4)0;

    // Batch offset
    dst_addr += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store the 4x4 block
    vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
1815
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1).
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
 *
 * Each work-item computes a 4x4 block of the output matrix. The loop along the common dimension is manually unrolled by 4,
 * and a tail loop processes the remaining steps one at a time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's sub-block inside the interleaved (A) and transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

    // Number of steps along the common dimension (each step consumes 4 values of a row of the transposed matrix B)
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    // One step along the common dimension: load 4 values of matrix A (interleaved) and 4 values of
    // matrix B (transposed), advance both pointers to the next step, then accumulate the 4x4 outer
    // product into c00..c33. Keeping the step in one macro guarantees that the unrolled loop and the
    // tail loop perform exactly the same sequence of loads and fma operations.
#define GEMM_MM_F32_BIFROST_STEP()                   \
    do                                               \
    {                                                \
        float4 a0 = vload4(0, src_addr_a);           \
        float4 b0 = vload4(0, src_addr_b);           \
                                                     \
        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; \
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;   \
                                                     \
        c00 = fma(a0.s0, b0.s0, c00);                \
        c01 = fma(a0.s0, b0.s1, c01);                \
        c02 = fma(a0.s0, b0.s2, c02);                \
        c03 = fma(a0.s0, b0.s3, c03);                \
                                                     \
        c10 = fma(a0.s1, b0.s0, c10);                \
        c11 = fma(a0.s1, b0.s1, c11);                \
        c12 = fma(a0.s1, b0.s2, c12);                \
        c13 = fma(a0.s1, b0.s3, c13);                \
                                                     \
        c20 = fma(a0.s2, b0.s0, c20);                \
        c21 = fma(a0.s2, b0.s1, c21);                \
        c22 = fma(a0.s2, b0.s2, c22);                \
        c23 = fma(a0.s2, b0.s3, c23);                \
                                                     \
        c30 = fma(a0.s3, b0.s0, c30);                \
        c31 = fma(a0.s3, b0.s1, c31);                \
        c32 = fma(a0.s3, b0.s2, c32);                \
        c33 = fma(a0.s3, b0.s3, c33);                \
    } while(0)

    int i = 0;

    // Main loop, manually unrolled by 4
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        GEMM_MM_F32_BIFROST_STEP();
        GEMM_MM_F32_BIFROST_STEP();
        GEMM_MM_F32_BIFROST_STEP();
        GEMM_MM_F32_BIFROST_STEP();
    }

    // Tail loop for the remaining (COLS_MTX_B % 4) steps
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        GEMM_MM_F32_BIFROST_STEP();
    }

#undef GEMM_MM_F32_BIFROST_STEP

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
2128
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002129#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** Computes the F16 matrix product between matrix A (src0) and matrix B (src1), accumulating in half precision.
 * Before calling this kernel, matrix A must be reshaped with @ref gemm_interleave4x4_16bit and matrix B with @ref gemm_transpose1x8.
 *
 * Every work-item writes a 4x8 tile of the destination matrix.
 *
 * @note -DCOLS_B (number of columns of matrix B) is required at compile time; -DALPHA (factor the product is scaled by) is optional
 * @note -DMULT_TRANSPOSE1XW_WIDTH sets the multiplication factor for the transposition width (mult_transpose1xW_width), e.g. -DMULT_TRANSPOSE1XW_WIDTH=2
 * @note -DMULT_INTERLEAVE4X4_HEIGHT sets the multiplication factor for the height of the 4x4 interleaved block, e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2
 * @note When matrix B has 3 dimensions while matrix A has more than 3, pass the number of channels of matrix B with -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       to avoid out-of-bounds reads. This happens when GEMM performs the element-wise multiplication of a batched matrix multiplication (2D Winograd)
 *       with multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note To reinterpret the output as a 3D tensor (e.g. output of a convolution layer), define at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: enables the 3D reinterpretation of the output
 *       -# HEIGHT_GEMM3D: height of the 3D output
 *       -# DEPTH_GEMM3D: depth of the 3D output
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to matrix A. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src0_stride_y                      Stride of matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes Offset of the first element of matrix A
 * @param[in]  src1_ptr                           Pointer to matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src1_stride_y                      Stride of matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes Offset of the first element of matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  Offset of the first element of the destination matrix
 * @param[in]  src0_stride_z                      Stride of matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Block coordinates of this work-item inside the reshaped matrices, plus the batch index
    const int blk_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int blk_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offsets of this work-item's sub-block: 4 rows of A, 8 columns of B
    const int a_sub_offset = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_sub_offset = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the first block of A and B read by this work-item
    int a_offset_bytes = batch * src0_stride_z + blk_y * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = blk_x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has fewer batches than matrix A: wrap the batch index instead of sliding past the end of B
    b_offset_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *ptr_a = (__global half *)(src0_ptr + a_offset_bytes);
    __global half *ptr_b = (__global half *)(src1_ptr + b_offset_bytes);

    // End of the current row of matrix B, taken before the sub-block offset is applied
    __global half *const ptr_b_end = ptr_b + COLS_B;

    ptr_a += a_sub_offset;
    ptr_b += b_sub_offset;

    // One accumulator per output row of the 4x8 tile
    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

    // Main loop: two steps along the common dimension per iteration
    while(ptr_b <= (ptr_b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        half4 a = vload4(0, ptr_a);
        half8 b = vload8(0, ptr_b);

        acc0 += (half8)a.s0 * b;
        acc1 += (half8)a.s1 * b;
        acc2 += (half8)a.s2 * b;
        acc3 += (half8)a.s3 * b;

        a = vload4(0, ptr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b = vload8(0, ptr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        acc0 += (half8)a.s0 * b;
        acc1 += (half8)a.s1 * b;
        acc2 += (half8)a.s2 * b;
        acc3 += (half8)a.s3 * b;

        ptr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        ptr_b += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Tail: one step at a time until the end of the row of B
    while(ptr_b < ptr_b_end)
    {
        const half4 a = vload4(0, ptr_a);
        const half8 b = vload8(0, ptr_b);

        acc0 += (half8)a.s0 * b;
        acc1 += (half8)a.s1 * b;
        acc2 += (half8)a.s2 * b;
        acc3 += (half8)a.s3 * b;

        ptr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        ptr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the product by alpha
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

    // Address of the first row of the destination tile
    Image dst_img            = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile may straddle planes of the 3D output; rows that fall in a later plane must skip
    // the cross-plane padding that separates consecutive planes:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane index of each of the 4 rows: M (get_global_id(1) * 4 + row) divided by HEIGHT_GEMM3D, clamped to the last plane
    uint4 plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset       = min(DEPTH_GEMM3D - 1, plane_offset);

    // Convert plane indices into byte offsets caused by the cross-plane padding
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is multiplied by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Write the 4x8 tile
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + plane_offset.s0));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + plane_offset.s1));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + plane_offset.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Move to the batch handled by this work-item
    dst_addr += batch * dst_stride_z;

    // Write the 4x8 tile
    vstore8(acc0, 0, (__global half *)(dst_addr));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002306
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32-bit floating-point variable.
2308 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
2309 *
2310 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2311 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2312 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2313 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2314 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2315 *
2316 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2317 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2318 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2319 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2320 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2321 *
2322 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2323 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2324 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2325 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2326 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2327 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2328 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2329 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2330 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2331 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2332 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2333 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src0_ptr
2335 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2336 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2337 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2338 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2339 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2340 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2341 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2342 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2343 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2344 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Half-precision inputs are widened to float when loaded and the 4x8 output tile is accumulated
    // in float8 registers; the results are narrowed back to half only when they are stored.

    // MULT_TRANSPOSE1XW_WIDTH neighbouring work-items along X share one row of the reshaped matrix B,
    // MULT_INTERLEAVE4X4_HEIGHT neighbouring work-items along Y share one row of the reshaped matrix A
    const int row_b = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int row_a = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offset of this work-item's slice inside the shared rows (4 values of A, 8 values of B)
    const int slice_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int slice_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more: wrap the batch index so that B is not slid past its depth
    const int batch_b = batch % MATRIX_B_DEPTH;
#else  // defined(MATRIX_B_DEPTH)
    const int batch_b = batch;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offsets of the first element of the A and B rows read by this work-item
    const int a_row_bytes = batch * src0_stride_z + row_a * src0_stride_y + src0_offset_first_element_in_bytes;
    const int b_row_bytes = row_b * src1_stride_y + batch_b * src1_stride_z + src1_offset_first_element_in_bytes;

    __global half *b_ptr = (__global half *)(src1_ptr + b_row_bytes);

    // The end of the B row is measured from the row start, i.e. before the slice offset is applied
    __global half *const b_end = b_ptr + COLS_B;

    __global half *a_ptr = (__global half *)(src0_ptr + a_row_bytes) + slice_a;
    b_ptr += slice_b;

    float8 acc0 = 0.0f;
    float8 acc1 = 0.0f;
    float8 acc2 = 0.0f;
    float8 acc3 = 0.0f;

    // Unrolled by two: each iteration consumes two 4-value blocks of A and two 8-value blocks of B
    __global half *const b_end_unrolled = b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH);
    while(b_ptr <= b_end_unrolled)
    {
        float4 a = convert_float4(vload4(0, a_ptr));
        float8 b = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a.s0 * b;
        acc1 += (float8)a.s1 * b;
        acc2 += (float8)a.s2 * b;
        acc3 += (float8)a.s3 * b;

        a = convert_float4(vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b = convert_float8(vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH));

        acc0 += (float8)a.s0 * b;
        acc1 += (float8)a.s1 * b;
        acc2 += (float8)a.s2 * b;
        acc3 += (float8)a.s3 * b;

        a_ptr += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Remainder: one block of A and one block of B per iteration
    while(b_ptr < b_end)
    {
        const float4 a = convert_float4(vload4(0, a_ptr));
        const float8 b = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)a.s0 * b;
        acc1 += (float8)a.s1 * b;
        acc2 += (float8)a.s2 * b;
        acc3 += (float8)a.s3 * b;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale by the weight of the matrix product
    acc0 *= (float8)ALPHA;
    acc1 *= (float8)ALPHA;
    acc2 *= (float8)ALPHA;
    acc3 *= (float8)ALPHA;
#endif // defined(ALPHA)

    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile is written into a 3D tensor. The 4 tile rows may fall into different planes, and each
    // plane is followed by cross plane padding, so every row gets its own extra byte offset:
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // Plane index of each tile row: M index (get_global_id(1) * 4 + row) divided by HEIGHT_GEMM3D, clamped
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Byte offset contributed by the padding: plane index * cross_plane_pad * dst_stride_y
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is scaled by DEPTH_GEMM3D
    out_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    vstore8(convert_half8(acc0), 0, (__global half *)(out_addr + 0 * dst_stride_y + zout.s0));
    vstore8(convert_half8(acc1), 0, (__global half *)(out_addr + 1 * dst_stride_y + zout.s1));
    vstore8(convert_half8(acc2), 0, (__global half *)(out_addr + 2 * dst_stride_y + zout.s2));
    vstore8(convert_half8(acc3), 0, (__global half *)(out_addr + 3 * dst_stride_y + zout.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one output matrix per Z slice
    out_addr += batch * dst_stride_z;

    vstore8(convert_half8(acc0), 0, (__global half *)(out_addr + 0 * dst_stride_y));
    vstore8(convert_half8(acc1), 0, (__global half *)(out_addr + 1 * dst_stride_y));
    vstore8(convert_half8(acc2), 0, (__global half *)(out_addr + 2 * dst_stride_y));
    vstore8(convert_half8(acc3), 0, (__global half *)(out_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
2483
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note Products are accumulated in half precision (half8 accumulators updated with fma). The kernel gemm_mm_interleaved_transposed_f16_acc32 accumulates in float instead.
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of the output, i.e. rows of matrix A before reshaping
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // MULT_TRANSPOSE1XW_WIDTH consecutive work-items along X read the same row of reshaped matrix B and
    // MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y read the same row of reshaped matrix A
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-value (A) / 8-value (B) slice within those rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (not used below: the loops are driven by COLS_MTX_B instead)
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (half precision)
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

// Number of 8 * MULT_TRANSPOSE1XW_WIDTH-wide steps needed to walk one row of reshaped matrix B
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop unrolled by 4 steps
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 the A blocks of two consecutive steps are contiguous,
        // so a single vload8 fetches the A values for two steps at once
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over steps, one at a time
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
2742
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002743#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002744
Gian Marco36a0a462018-01-12 10:21:40 +00002745#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002746
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002747#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
2748#if defined(DATA_TYPE)
2749#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01002750/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002751 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002752 * @note This OpenCL kernel works with floating point data types (F16/F32)
2753 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
2754 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002755 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002756 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2757 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002758 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002759 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2760 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002761 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2762 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2763 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A (not reshaped)
2765 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002766 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002767 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2768 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2769 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2770 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2771 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002772 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002773 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2774 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2775 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2776 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2777 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002778 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002779 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
2783 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002784 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2785 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2786 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002787 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2788 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002789 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002790__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
2791 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002792 IMAGE_DECLARATION(dst),
2793 uint src0_stride_z,
2794 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002795 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002796#if defined(REINTERPRET_INPUT_AS_3D)
2797 ,
2798 uint src_cross_plane_pad
2799#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002800#if defined(REINTERPRET_OUTPUT_AS_3D)
2801 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002802 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002803#endif // REINTERPRET_OUTPUT_AS_3D
2804 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002805{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002806 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002807
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002808 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002809 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002810
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002811 // Update address for the matrix A
2812 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002813
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002814 // Update address for the matrix B
2815 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002816
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002817#if defined(REINTERPRET_INPUT_AS_3D)
2818 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2819 // in order to take into account the presence of possible cross plane paddings
2820 //
2821 // | |
2822 // | plane0 |
2823 // | |
2824 // |__________________|
2825 // |******************|
2826 // | cross_plane_pad |
2827 // |******************|
2828 // | |
2829 // | plane1 |
2830 // | |
2831 // |__________________|
2832
2833 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2834 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2835 zin = min(DEPTH_GEMM3D - 1, zin);
2836
2837 // Add offset due to the cross plane paddings
2838 zin *= (src_cross_plane_pad * src0_stride_y);
2839
2840 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2841 // multiply src0_stride_z by DEPTH_GEMM3D
2842 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2843
2844#else // defined(REINTERPRET_INPUT_AS_3D)
2845
Gian Marcoae2af742018-02-15 12:35:44 +00002846 // Add offset for batched GEMM
2847 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002848
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002849#endif // defined(REINTERPRET_INPUT_AS_3D)
2850
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002851#if defined(MATRIX_B_DEPTH)
2852 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2853 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2854#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002855 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002856#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002857
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002858 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
2859
2860 VECTOR_TYPE acc0 = 0.0f;
2861#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2862 VECTOR_TYPE acc1 = 0.0f;
2863#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2864#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2865 VECTOR_TYPE acc2 = 0.0f;
2866#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2867#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2868 VECTOR_TYPE acc3 = 0.0f;
2869#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2870
Georgios Pinitas96880cf2017-10-20 18:52:20 +01002871 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002872 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002873#if defined(REINTERPRET_INPUT_AS_3D)
2874 // Load values from matrix A
2875 VEC_DATA_TYPE(DATA_TYPE, 2)
2876 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2877#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2878 VEC_DATA_TYPE(DATA_TYPE, 2)
2879 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2880#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2881#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2882 VEC_DATA_TYPE(DATA_TYPE, 2)
2883 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2884#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2885#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2886 VEC_DATA_TYPE(DATA_TYPE, 2)
2887 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2888#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2889#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002890 // Load values from matrix A
2891 VEC_DATA_TYPE(DATA_TYPE, 2)
2892 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2893#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2894 VEC_DATA_TYPE(DATA_TYPE, 2)
2895 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2896#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2897#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2898 VEC_DATA_TYPE(DATA_TYPE, 2)
2899 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2900#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2901#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2902 VEC_DATA_TYPE(DATA_TYPE, 2)
2903 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2904#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002905#endif // defined(REINTERPRET_INPUT_AS_3D)
2906
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002907 // Load values from matrix B
2908 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
2909 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002910
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002911 // Accumulate
2912 acc0 += b0 * (VECTOR_TYPE)a0.s0;
2913 acc0 += b1 * (VECTOR_TYPE)a0.s1;
2914#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2915 acc1 += b0 * (VECTOR_TYPE)a1.s0;
2916 acc1 += b1 * (VECTOR_TYPE)a1.s1;
2917#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2918#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2919 acc2 += b0 * (VECTOR_TYPE)a2.s0;
2920 acc2 += b1 * (VECTOR_TYPE)a2.s1;
2921#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2922#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2923 acc3 += b0 * (VECTOR_TYPE)a3.s0;
2924 acc3 += b1 * (VECTOR_TYPE)a3.s1;
2925#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002926 }
2927
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002928 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002929 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002930#if defined(REINTERPRET_INPUT_AS_3D)
2931 // Load values from matrix A
2932 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2933#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2934 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2935#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2936#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2937 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2938#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2939#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2940 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2941#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2942#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002943 // Load values from matrix A
2944 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2945#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2946 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2947#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2948#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2949 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2950#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2951#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2952 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2953#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002954#endif // defined(REINTERPRET_INPUT_AS_3D)
2955
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002956 // Load values from matrix B
2957 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002958
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002959 // Accumulate
2960 acc0 += b0 * (VECTOR_TYPE)a0;
2961#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2962 acc1 += b0 * (VECTOR_TYPE)a1;
2963#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2964#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2965 acc2 += b0 * (VECTOR_TYPE)a2;
2966#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2967#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2968 acc3 += b0 * (VECTOR_TYPE)a3;
2969#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002970 }
2971
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002972 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002973 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2974
Gian Marcoae2af742018-02-15 12:35:44 +00002975 // Compute dst address
2976 __global uchar *dst_addr = offset(&dst, 0, 0);
2977
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002978 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002979#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002980 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002981#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002982#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2983 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
2984#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2985#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2986 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
2987#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2988#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2989 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
2990#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2991
2992 int z = get_global_id(2);
2993
2994#if defined(REINTERPRET_OUTPUT_AS_3D)
2995 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002996 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002997 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002998 // | |
2999 // | plane0 |
3000 // | |
3001 // |__________________|
3002 // |******************|
3003 // | cross_plane_pad |
3004 // |******************|
3005 // | |
3006 // | plane1 |
3007 // | |
3008 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003009
3010 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3011 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3012 zout = min(DEPTH_GEMM3D - 1, zout);
3013
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003014 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003015 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003016
3017 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3018 // multiply dst_stride_z by DEPTH_GEMM3D
3019 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3020
3021 // Store output block
3022 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3023 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
3024#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3025 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3026 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
3027#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3028#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3029 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3030 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
3031#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3032#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3033 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
3034 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
3035#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3036
3037#else // defined(REINTERPRET_OUTPUT_AS_3D)
3038 // Add offset for batched GEMM
3039 dst_addr += z * dst_stride_z;
3040
3041 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003042 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003043 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003044#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003045 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003046 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003047#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3048#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003049 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003050 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003051#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3052#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003053 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003054 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003055#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003056#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003057}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003058#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003059
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003060/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003061 *
3062 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
3063 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3064 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3065 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3066 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003067 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3068 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003069 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003070 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3071 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003072 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3073 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3074 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3075 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3076 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003077 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
3078 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3079 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3080 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3081 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3082 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3083 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3084 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3085 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3086 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3087 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3088 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
3089 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3090 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3091 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3092 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3093 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3094 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003095 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3096 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3097 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003098 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3099 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003100 */
3101__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
3102 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00003103 IMAGE_DECLARATION(dst),
3104 uint src0_stride_z,
3105 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003106 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003107#if defined(REINTERPRET_INPUT_AS_3D)
3108 ,
3109 uint src_cross_plane_pad
3110#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003111#if defined(REINTERPRET_OUTPUT_AS_3D)
3112 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003113 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003114#endif // REINTERPRET_OUTPUT_AS_3D
3115 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003116{
3117 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3118
3119 // Compute starting address for matrix A and matrix B
3120 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3121
3122 // Update address for matrix A
3123 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3124
3125 // Update address for matrix B
3126 src_addr.s1 += idx * sizeof(float);
3127
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003128#if defined(REINTERPRET_INPUT_AS_3D)
3129 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3130 // in order to take into account the presence of possible cross plane paddings
3131 //
3132 // | |
3133 // | plane0 |
3134 // | |
3135 // |__________________|
3136 // |******************|
3137 // | cross_plane_pad |
3138 // |******************|
3139 // | |
3140 // | plane1 |
3141 // | |
3142 // |__________________|
3143
3144 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3145 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3146 zin = min(DEPTH_GEMM3D - 1, zin);
3147
3148 // Add offset due to the cross plane paddings
3149 zin *= (src_cross_plane_pad * src0_stride_y);
3150
3151 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3152 // multiply src0_stride_z by DEPTH_GEMM3D
3153 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3154
3155#else // defined(REINTERPRET_INPUT_AS_3D)
3156
Gian Marcoae2af742018-02-15 12:35:44 +00003157 // Add offset for batched GEMM
3158 src_addr.s0 += get_global_id(2) * src0_stride_z;
3159
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003160#endif // defined(REINTERPRET_INPUT_AS_3D)
3161
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003162#if defined(MATRIX_B_DEPTH)
3163 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3164 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3165#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003166 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003167#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003168
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003169 // Initialize accumulators
3170 float acc00 = 0.0f;
3171 float acc01 = 0.0f;
3172 float acc02 = 0.0f;
3173 float acc03 = 0.0f;
3174
3175#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3176 float acc10 = 0.0f;
3177 float acc11 = 0.0f;
3178 float acc12 = 0.0f;
3179 float acc13 = 0.0f;
3180#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3181
3182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3183 float acc20 = 0.0f;
3184 float acc21 = 0.0f;
3185 float acc22 = 0.0f;
3186 float acc23 = 0.0f;
3187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3188
3189#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3190 float acc30 = 0.0f;
3191 float acc31 = 0.0f;
3192 float acc32 = 0.0f;
3193 float acc33 = 0.0f;
3194#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3195
3196 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003197 int i = 0;
3198 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003199 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003200#if defined(REINTERPRET_INPUT_AS_3D)
3201 // Load values from matrix A and matrix B
3202 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3203#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3204 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3205#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3206#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3207 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3208#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3210 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3212#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003213 // Load values from matrix A and matrix B
3214 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003215#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003216 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003217#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3218#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003219 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003220#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3221#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003222 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003223#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003224#endif // defined(REINTERPRET_INPUT_AS_3D)
3225
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003226 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3227 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003228
3229 // Multiply and accumulate
3230 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003231 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003232 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003233 acc03 = fma(a0.s0, b0.s3, acc03);
3234
3235#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003236
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003237 acc10 = fma(a1.s0, b0.s0, acc10);
3238 acc11 = fma(a1.s0, b0.s1, acc11);
3239 acc12 = fma(a1.s0, b0.s2, acc12);
3240 acc13 = fma(a1.s0, b0.s3, acc13);
3241
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003242#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3243#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003244
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003245 acc20 = fma(a2.s0, b0.s0, acc20);
3246 acc21 = fma(a2.s0, b0.s1, acc21);
3247 acc22 = fma(a2.s0, b0.s2, acc22);
3248 acc23 = fma(a2.s0, b0.s3, acc23);
3249
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003250#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003252
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003253 acc30 = fma(a3.s0, b0.s0, acc30);
3254 acc31 = fma(a3.s0, b0.s1, acc31);
3255 acc32 = fma(a3.s0, b0.s2, acc32);
3256 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003257#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003258
3259 // Load values from matrix A and matrix B
3260 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3261 src_addr.s1 += src1_stride_y;
3262
3263 // Multiply and accumulate
3264 acc00 = fma(a0.s1, b0.s0, acc00);
3265 acc01 = fma(a0.s1, b0.s1, acc01);
3266 acc02 = fma(a0.s1, b0.s2, acc02);
3267 acc03 = fma(a0.s1, b0.s3, acc03);
3268
3269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3270
3271 acc10 = fma(a1.s1, b0.s0, acc10);
3272 acc11 = fma(a1.s1, b0.s1, acc11);
3273 acc12 = fma(a1.s1, b0.s2, acc12);
3274 acc13 = fma(a1.s1, b0.s3, acc13);
3275
3276#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3277#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3278
3279 acc20 = fma(a2.s1, b0.s0, acc20);
3280 acc21 = fma(a2.s1, b0.s1, acc21);
3281 acc22 = fma(a2.s1, b0.s2, acc22);
3282 acc23 = fma(a2.s1, b0.s3, acc23);
3283
3284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3286
3287 acc30 = fma(a3.s1, b0.s0, acc30);
3288 acc31 = fma(a3.s1, b0.s1, acc31);
3289 acc32 = fma(a3.s1, b0.s2, acc32);
3290 acc33 = fma(a3.s1, b0.s3, acc33);
3291#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3292
3293 // Load values from matrix A and matrix B
3294 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3295 src_addr.s1 += src1_stride_y;
3296
3297 // Multiply and accumulate
3298 acc00 = fma(a0.s2, b0.s0, acc00);
3299 acc01 = fma(a0.s2, b0.s1, acc01);
3300 acc02 = fma(a0.s2, b0.s2, acc02);
3301 acc03 = fma(a0.s2, b0.s3, acc03);
3302
3303#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3304
3305 acc10 = fma(a1.s2, b0.s0, acc10);
3306 acc11 = fma(a1.s2, b0.s1, acc11);
3307 acc12 = fma(a1.s2, b0.s2, acc12);
3308 acc13 = fma(a1.s2, b0.s3, acc13);
3309
3310#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3311#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3312
3313 acc20 = fma(a2.s2, b0.s0, acc20);
3314 acc21 = fma(a2.s2, b0.s1, acc21);
3315 acc22 = fma(a2.s2, b0.s2, acc22);
3316 acc23 = fma(a2.s2, b0.s3, acc23);
3317
3318#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3319#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3320
3321 acc30 = fma(a3.s2, b0.s0, acc30);
3322 acc31 = fma(a3.s2, b0.s1, acc31);
3323 acc32 = fma(a3.s2, b0.s2, acc32);
3324 acc33 = fma(a3.s2, b0.s3, acc33);
3325#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3326
3327 // Load values from matrix A and matrix B
3328 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
3329 src_addr.s1 += src1_stride_y;
3330
3331 // Multiply and accumulate
3332 acc00 = fma(a0.s3, b0.s0, acc00);
3333 acc01 = fma(a0.s3, b0.s1, acc01);
3334 acc02 = fma(a0.s3, b0.s2, acc02);
3335 acc03 = fma(a0.s3, b0.s3, acc03);
3336
3337#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3338
3339 acc10 = fma(a1.s3, b0.s0, acc10);
3340 acc11 = fma(a1.s3, b0.s1, acc11);
3341 acc12 = fma(a1.s3, b0.s2, acc12);
3342 acc13 = fma(a1.s3, b0.s3, acc13);
3343
3344#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3346
3347 acc20 = fma(a2.s3, b0.s0, acc20);
3348 acc21 = fma(a2.s3, b0.s1, acc21);
3349 acc22 = fma(a2.s3, b0.s2, acc22);
3350 acc23 = fma(a2.s3, b0.s3, acc23);
3351
3352#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3353#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3354
3355 acc30 = fma(a3.s3, b0.s0, acc30);
3356 acc31 = fma(a3.s3, b0.s1, acc31);
3357 acc32 = fma(a3.s3, b0.s2, acc32);
3358 acc33 = fma(a3.s3, b0.s3, acc33);
3359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3360
3361 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003362 }
3363
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003364 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003365 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003366#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003367 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003368 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3370 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3371#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3373 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3376 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3377#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3378#else // defined(REINTERPRET_INPUT_AS_3D)
3379 // Load values from matrix A
3380 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003381#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3382 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3383#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3384#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3385 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3386#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3387#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3388 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003390#endif // defined(REINTERPRET_INPUT_AS_3D)
3391
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003392 // Load values from matrix B
3393 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003394 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003395
3396 // Multiply and accumulate
3397 acc00 = fma(a0, b0.s0, acc00);
3398 acc01 = fma(a0, b0.s1, acc01);
3399 acc02 = fma(a0, b0.s2, acc02);
3400 acc03 = fma(a0, b0.s3, acc03);
3401#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3402 acc10 = fma(a1, b0.s0, acc10);
3403 acc11 = fma(a1, b0.s1, acc11);
3404 acc12 = fma(a1, b0.s2, acc12);
3405 acc13 = fma(a1, b0.s3, acc13);
3406#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3407#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3408 acc20 = fma(a2, b0.s0, acc20);
3409 acc21 = fma(a2, b0.s1, acc21);
3410 acc22 = fma(a2, b0.s2, acc22);
3411 acc23 = fma(a2, b0.s3, acc23);
3412#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3413#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3414 acc30 = fma(a3, b0.s0, acc30);
3415 acc31 = fma(a3, b0.s1, acc31);
3416 acc32 = fma(a3, b0.s2, acc32);
3417 acc33 = fma(a3, b0.s3, acc33);
3418#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003419
3420 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003421 }
3422
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003423 int z = get_global_id(2);
3424
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003425 // Compute destination address
3426 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3427
3428 // Multiply by the weight of matrix-matrix product and store the result
3429#if defined(ALPHA)
3430 acc00 = acc00 * ALPHA;
3431 acc01 = acc01 * ALPHA;
3432 acc02 = acc02 * ALPHA;
3433 acc03 = acc03 * ALPHA;
3434#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003435#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003436 acc10 = acc10 * ALPHA;
3437 acc11 = acc11 * ALPHA;
3438 acc12 = acc12 * ALPHA;
3439 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003442 acc20 = acc20 * ALPHA;
3443 acc21 = acc21 * ALPHA;
3444 acc22 = acc22 * ALPHA;
3445 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003446#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3447#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003448 acc30 = acc30 * ALPHA;
3449 acc31 = acc31 * ALPHA;
3450 acc32 = acc32 * ALPHA;
3451 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003452#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3453
3454 // Compute dst address
3455 __global uchar *dst_addr = offset(&dst, 0, 0);
3456
3457#if defined(REINTERPRET_OUTPUT_AS_3D)
3458 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003459 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003460 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003461 // | |
3462 // | plane0 |
3463 // | |
3464 // |__________________|
3465 // |******************|
3466 // | cross_plane_pad |
3467 // |******************|
3468 // | |
3469 // | plane1 |
3470 // | |
3471 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003472
3473 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3474 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3475 zout = min(DEPTH_GEMM3D - 1, zout);
3476
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003477 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003478 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003479
3480 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3481 // multiply dst_stride_z by DEPTH_GEMM3D
3482 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3483
3484 // Store the output block
3485 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3486#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3487 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3488#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3489#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3490 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3491#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3493 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003495
3496#else // defined(REINTERPRET_OUTPUT_AS_3D)
3497 // Add offset for batched GEMM
3498 dst_addr += z * dst_stride_z;
3499
3500 // Store the output block
3501 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
3502#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3503 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
3504#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3505#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3506 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
3507#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3508#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3509 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
3510#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3511#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003512}
3513
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item always computes exactly 2 output columns (matrix B and the output are accessed with vload2/vstore2);
 *       NUM_ELEMS_PROCESSED_PER_THREAD_X only sets the column stride between work-items, so use -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
 *       Only NUM_ELEMS_PROCESSED_PER_THREAD_Y values from 1 to 4 are handled: rows beyond the 4th are not computed.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings, in number of rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings, in number of rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes a NUM_ELEMS_PROCESSED_PER_THREAD_Y x 2 output tile: matrix A rows are read
    // 8 floats at a time (vload8) and matrix B rows 2 floats at a time (vload2).
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (the padding is a number of rows, hence the src0_stride_y factor)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accXY holds output row X, column Y of the tile
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consume 8 columns of matrix A (and 8 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }
    // Left-over loop: process the remaining (COLS_A % 8) columns of matrix A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product (only emitted when -DALPHA is given, i.e. alpha != 1.0f)
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (the padding is a number of rows, hence the dst_stride_y factor)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3900
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01003901#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01003902/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 3903 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00003904 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
3905 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3906 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3907 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3908 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
3909 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3910 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
3911 *
3912 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3913 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
3914 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3915 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3916 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3917 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3918 *
3919 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3920 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3921 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3922 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3923 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3924 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3925 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3926 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3927 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3928 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3929 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3930 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
3931 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3932 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3933 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3934 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3935 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3936 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3937 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3938 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3939 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3940 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3941 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3942 */
3943__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
3944 IMAGE_DECLARATION(src1),
3945 IMAGE_DECLARATION(dst),
3946 uint src0_stride_z,
3947 uint src1_stride_z,
3948 uint dst_stride_z
3949#if defined(REINTERPRET_INPUT_AS_3D)
3950 ,
3951 uint src_cross_plane_pad
3952#endif // REINTERPRET_INPUT_AS_3D
3953#if defined(REINTERPRET_OUTPUT_AS_3D)
3954 ,
3955 uint dst_cross_plane_pad
3956#endif // REINTERPRET_OUTPUT_AS_3D
3957 )
3958{
3959 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3960
3961 // Compute starting address for matrix A and Matrix B
3962 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3963
3964 // Update address for the matrix A
3965 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3966
3967 // Update address for the matrix B
3968 src_addr.s1 += idx * sizeof(half);
3969
3970#if defined(REINTERPRET_INPUT_AS_3D)
3971 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3972 // in order to take into account the presence of possible cross plane paddings
3973 //
3974 // | |
3975 // | plane0 |
3976 // | |
3977 // |__________________|
3978 // |******************|
3979 // | cross_plane_pad |
3980 // |******************|
3981 // | |
3982 // | plane1 |
3983 // | |
3984 // |__________________|
3985
3986 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3987 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3988 zin = min(DEPTH_GEMM3D - 1, zin);
3989
3990 // Add offset due to the cross plane paddings
3991 zin *= (src_cross_plane_pad * src0_stride_y);
3992
3993 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3994 // multiply src0_stride_z by DEPTH_GEMM3D
3995 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3996
3997#else // defined(REINTERPRET_INPUT_AS_3D)
3998
3999 // Add offset for batched GEMM
4000 src_addr.s0 += get_global_id(2) * src0_stride_z;
4001
4002#endif // defined(REINTERPRET_INPUT_AS_3D)
4003
4004#if defined(MATRIX_B_DEPTH)
4005 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4006 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4007#else // defined(MATRIX_B_DEPTH)
4008 src_addr.s1 += get_global_id(2) * src1_stride_z;
4009#endif // defined(MATRIX_B_DEPTH)
4010
4011 float8 acc0 = 0.0h;
4012#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4013 float8 acc1 = 0.0h;
4014#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4015#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4016 float8 acc2 = 0.0h;
4017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4018#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4019 float8 acc3 = 0.0h;
4020#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4021
4022 int i = 0;
4023 for(; i <= ((int)COLS_A - 4); i += 4)
4024 {
4025#if defined(REINTERPRET_INPUT_AS_3D)
4026 // Load values from matrix A
4027 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4028#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4029 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4030#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4031#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4032 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4033#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4034#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4035 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4036#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4037#else // defined(REINTERPRET_INPUT_AS_3D)
4038 // Load values from matrix A
4039 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4040#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4041 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4042#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4043#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4044 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4045#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4046#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4047 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4048#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4049#endif // defined(REINTERPRET_INPUT_AS_3D)
4050
4051 // Load values from matrix B
4052 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4053 src_addr.s1 += src1_stride_y;
4054
4055 // Accumulate
4056 acc0 = fma(b0, (float8)a0.s0, acc0);
4057#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4058 acc1 = fma(b0, (float8)a1.s0, acc1);
4059#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4060#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4061 acc2 = fma(b0, (float8)a2.s0, acc2);
4062#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4063#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4064 acc3 = fma(b0, (float8)a3.s0, acc3);
4065#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4066
4067 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4068 src_addr.s1 += src1_stride_y;
4069 acc0 = fma(b0, (float8)a0.s1, acc0);
4070#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4071 acc1 = fma(b0, (float8)a1.s1, acc1);
4072#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4073#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4074 acc2 = fma(b0, (float8)a2.s1, acc2);
4075#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4076#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4077 acc3 = fma(b0, (float8)a3.s1, acc3);
4078#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4079
4080 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4081 src_addr.s1 += src1_stride_y;
4082 acc0 = fma(b0, (float8)a0.s2, acc0);
4083#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4084 acc1 = fma(b0, (float8)a1.s2, acc1);
4085#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4086#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4087 acc2 = fma(b0, (float8)a2.s2, acc2);
4088#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4089#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4090 acc3 = fma(b0, (float8)a3.s2, acc3);
4091#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4092
4093 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4094 src_addr.s1 += src1_stride_y;
4095 acc0 = fma(b0, (float8)a0.s3, acc0);
4096#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4097 acc1 = fma(b0, (float8)a1.s3, acc1);
4098#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4099#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4100 acc2 = fma(b0, (float8)a2.s3, acc2);
4101#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4102#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4103 acc3 = fma(b0, (float8)a3.s3, acc3);
4104#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4105
4106 src_addr.s0 += 4 * sizeof(half);
4107 }
4108
4109 for(; i < (int)COLS_A; ++i)
4110 {
4111#if defined(REINTERPRET_INPUT_AS_3D)
4112 // Load values from matrix A
4113 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4114#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4115 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4116#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4117#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4118 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4119#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4120#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4121 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4122#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4123#else // defined(REINTERPRET_INPUT_AS_3D)
4124 // Load values from matrix A
4125 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4126#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4127 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4128#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4129#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4130 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4131#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4132#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4133 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4134#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4135#endif // defined(REINTERPRET_INPUT_AS_3D)
4136
4137 // Load values from matrix B
4138 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
4139
4140 src_addr += (int2)(sizeof(half), src1_stride_y);
4141
4142 // Accumulate
4143 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
4144#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4145 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
4146#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4147#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4148 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
4149#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4150#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4151 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
4152#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4153 }
4154
4155 // Multiply by the weight of matrix-matrix product and store the result
4156#if defined(ALPHA)
4157 half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
4158#else //defined(ALPHA)
4159 half8 hacc0 = convert_half8(acc0);
4160#endif // defined(ALPHA)
4161#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4162#if defined(ALPHA)
4163 half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
4164#else //defined(ALPHA)
4165 half8 hacc1 = convert_half8(acc1);
4166#endif //defined(ALPHA)
4167#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
4168
4169#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4170#if defined(ALPHA)
4171 half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
4172#else //defined(ALPHA)
4173 half8 hacc2 = convert_half8(acc2);
4174#endif //defined(ALPHA)
4175#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4176
4177#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4178#if defined(ALPHA)
4179 half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
4180#else //defined(ALPHA)
4181 half8 hacc3 = convert_half8(acc3);
4182#endif // defined(ALPHA)
4183#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4184
4185 int z = get_global_id(2);
4186
4187 // Compute destination address
4188 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4189
4190 // Compute dst address
4191 __global uchar *dst_addr = offset(&dst, 0, 0);
4192
4193#if defined(REINTERPRET_OUTPUT_AS_3D)
4194 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
4195 // in order to take into account the presence of possible cross plane paddings
4196 //
4197 // | |
4198 // | plane0 |
4199 // | |
4200 // |__________________|
4201 // |******************|
4202 // | cross_plane_pad |
4203 // |******************|
4204 // | |
4205 // | plane1 |
4206 // | |
4207 // |__________________|
4208
4209 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4210 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4211 zout = min(DEPTH_GEMM3D - 1, zout);
4212
4213 // Add offset due to the cross plane paddings
4214 zout *= (dst_cross_plane_pad * dst_stride_y);
4215
4216 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4217 // multiply dst_stride_z by DEPTH_GEMM3D
4218 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4219
4220 // Store the output block
4221 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4222#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4223 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4224#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4225#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4226 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4227#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4228#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4229 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
4230#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4231
4232#else // defined(REINTERPRET_OUTPUT_AS_3D)
4233 // Add offset for batched GEMM
4234 dst_addr += z * dst_stride_z;
4235
4236 // Store the output block
4237 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
4238#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4239 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
4240#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4241#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4242 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
4243#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4244#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4245 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
4246#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4247#endif // REINTERPRET_OUTPUT_AS_3D
4248}
4249
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 *       Accumulation is performed in half precision (the accumulators acc0..acc3 are half8).
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads and stores 8 consecutive columns (vload8/vstore8), so -DNUM_ELEMS_PROCESSED_PER_THREAD_X is expected to be 8
 *       (NOTE(review): confirm against the host-side kernel configuration). -DNUM_ELEMS_PROCESSED_PER_THREAD_Y may range from 1 to 4.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped (a row's plane index is row / HEIGHT_GEMM3D)
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly to uint4: the min() built-in only
    // accepts (vector, vector) or (vector, scalar), never a signed scalar as the first argument.
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row handled by this work-item (accumulation is in FP16)
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Leftover loop: remaining COLS_A % 4 columns, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance A by one element along X and B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc0 = acc0 * (half8)ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane (explicit uint4 splat: see the matching comment for zin)
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004581#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01004582
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004583#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004584
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004585#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *     dst = dst + BETA * src
 *
 * dst is read and then overwritten in place. By the naming used in the kernel, dst is expected to already hold
 * alpha * (A x B) and src to hold matrix C (NOTE(review): confirm against the caller).
 * Each work-item reads and writes 4 consecutive float values (vload4/vstore4) with no bounds checks.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]     src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix (read, then overwritten with the result). Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the same locations are overwritten below)
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
4626
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004627#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *     dst = dst + BETA * src
 *
 * dst is read and then overwritten in place. By the naming used in the kernel, dst is expected to already hold
 * alpha * (A x B) and src to hold matrix C (NOTE(review): confirm against the caller).
 * Each work-item reads and writes 8 consecutive half values (vload8/vstore8) with no bounds checks.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]     src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]     src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]     src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]     src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]     src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in,out] dst_ptr                           Pointer to the destination matrix (read, then overwritten with the result). Supported data types: same as @p src_ptr
 * @param[in]     dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]     dst_step_x                        dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]     dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]     dst_step_y                        dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]     dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]     dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]     dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the same locations are overwritten below)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01004668#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004669#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004670
#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix product used by the locally connected layer: each row of A (src0) is multiplied
 *  by its own matrix B (one Z-slice of src1).
 *
 * Work-item (x, y) uses row y of A and Z-slice y of src1, and produces four consecutive results
 * from columns 4 * x .. 4 * x + 3 of that slice. The four results are written at the work-item's
 * position in dst as resolved by CONVERT_TO_IMAGE_STRUCT. B is always read, and dst always written,
 * four floats at a time; there is no handling of a partial final group of columns.
 *
 * @note The width of A (number of elements per row) needs to be passed at compile time using -DWIDTH_VECTOR_A
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in matrix A
 * @param[in]  src1_ptr                           Pointer to the stack of B matrices. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the B matrices in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the B matrices in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_stride_z                      Stride between consecutive B matrices, i.e. in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the B matrices
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int col = get_global_id(0) * 4;
    const int row = get_global_id(1);

    // Byte offsets: a_addr walks along row `row` of A, b_addr walks down column `col` of the `row`-th B
    int a_addr = src0_offset_first_element_in_bytes + src0_stride_y * row;
    int b_addr = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_addr += col * sizeof(float);

    // One past the last byte of the row of A
    const int a_end = a_addr + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = (float4)0.0f;

    // Main loop: two elements of A (and therefore two rows of B) per iteration
    while(a_addr <= (a_end - 2 * (int)sizeof(float)))
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_addr));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_addr));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_addr + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        a_addr += 2 * (int)sizeof(float);
        b_addr += 2 * src1_stride_y;
    }

    // Tail: the single leftover element when WIDTH_VECTOR_A is odd
    while(a_addr < a_end)
    {
        const float  a_val = *((__global float *)(src0_ptr + a_addr));
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + b_addr));

        acc += b_row * (float4)a_val;

        a_addr += (int)sizeof(float);
        b_addr += src1_stride_y;
    }

    // Write the four results at this work-item's position in dst
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)(offset(&dst_img, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
4738
/** Adds the biases vector to the accumulate tensor, in place (each row is accumulated with the biases).
 *
 * Each work-item loads VECTOR_SIZE consecutive elements from accum and VECTOR_SIZE consecutive elements
 * from biases, at the positions resolved by CONVERT_TO_IMAGE_STRUCT / CONVERT_TO_VECTOR_STRUCT, and writes
 * their sum back to the same location in accum.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  acc_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector bias_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    // The accumulate buffer is both read and written, so resolve its typed address once
    __global DATA_TYPE *acc_ptr = (__global DATA_TYPE *)acc_img.ptr;

    // VECTOR_SIZE elements of the running sum and the matching VECTOR_SIZE biases
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) partial = VLOAD(VECTOR_SIZE)(0, acc_ptr);
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) bias    = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)bias_vec.ptr);

    // Write the biased values back into the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (bias + partial, 0, acc_ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)