blob: 45c600cd373833ea9abad5c14fd8e77fae6f4f80 [file] [log] [blame]
/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// INCn expands to the uint vector (0, 1, ..., n-1); used to build per-lane column indices.
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
// Two-level expansion so that a macro argument such as K0 is expanded to its value
// before being pasted onto "INC" (INC(K0) -> INC4 when K0 is 4).
#define CONCAT_INC(n) INC##n
#define INC(n) CONCAT_INC(n)

/** Zero the lanes of a K0-wide vector whose source column falls at or beyond SRC_WIDTH.
 *
 * Lane i of @p a is treated as source column ((x) * K0 + i). Lanes with a column index
 * smaller than SRC_WIDTH keep their value; the remaining lanes are replaced with 0.
 * When SRC_WIDTH is a multiple of K0 no block can straddle the right edge, so the macro
 * expands to an empty statement and costs nothing.
 *
 * @param[in]     x Block index along X, in units of K0 elements
 * @param[in,out] a VEC_DATA_TYPE(DATA_TYPE, K0) lvalue, masked in place
 *
 * NOTE(review): per the OpenCL C spec, select() expects an integer-typed mask. Here the
 * comparison result is converted to VEC_DATA_TYPE(DATA_TYPE, K0), which is only an integer
 * type for integer DATA_TYPEs. Confirm that the target compilers accept this for F16/F32.
 */
#if(SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a)                                                                                                                         \
    ({                                                                                                                                                     \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * Work-item mapping: get_global_id(0) selects the block column (K0 elements wide), get_global_id(1) selects the block row (M0 rows high)
 * and get_global_id(2) selects the batch. Block row y is written to output row (y / V0), at slot (y % V0) within that row.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16).
 *       Elements whose column index is >= SRC_WIDTH are written as zero (see BOUNDARY_CONDITION_X).
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Every work-item loads a full M0xK0 tile without bounds checks on the load itself.
 *       NOTE(review): this presumably relies on the caller padding the source tensor; confirm.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Padding between consecutive planes, counted in rows (elements along Y) (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // Element offset of slot (y % V0) inside an output row:
    // interleaved blocks share rows (offset by K0), otherwise each block is contiguous
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Element distance between two consecutive rows of the same block in the output
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: top-left element of block (x, y)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: output row (y / V0), block column x, slot (y % V0)
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    // zinN is the extra byte offset applied to row N to skip cross-plane padding (0 when not reinterpreted as 3D)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing the row index (y * M0 + N) by HEIGHT_GEMM3D,
    // clamped to the last plane, then scaled to the byte size of the padding rows it has to skip
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (K0 elements each) from the LHS matrix,
    // zeroing the lanes that fall past SRC_WIDTH
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Row N of the block is stored OUTPUT_STEP_X * N elements after output_ptr
    VSTORE(K0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if M0 > 1
    VSTORE(K0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 1
#if M0 > 2
    VSTORE(K0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 2
#if M0 > 3
    VSTORE(K0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 3
#if M0 > 4
    VSTORE(K0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 4
#if M0 > 5
    VSTORE(K0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 5
#if M0 > 6
    VSTORE(K0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 6
#if M0 > 7
    VSTORE(K0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 7

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}

/** Transpose column @p i of the M0xK0 block held in a0..a(M0-1) and store it as one output row.
 *
 * Gathers element i of every loaded row (a0.s<i>, a1.s<i>, ..., a(M0-1).s<i>) into M0 consecutive
 * elements and stores them (i * output_step_x) elements after @p output_ptr. The row vectors
 * a0..a(M0-1) are not parameters: they must be in scope where the macro is expanded.
 *
 * @param[in] output_ptr    Byte pointer (e.g. __global uchar *) to the start of the destination block
 * @param[in] output_step_x Distance, in elements, between two consecutive transposed rows
 * @param[in] i             Column index written as a single hex digit (0-9, A-F). It is pasted both into
 *                          the vector component selector (.s##i) and into a hex literal (0x##i).
 *
 * For M0 = 5, 6, 7 the column is split into a 4-wide store followed by a store of the remaining elements.
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                      \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                           \
        VSTORE(M0)                                                                                        \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                      \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                                  \
        VSTORE(M0)                                                                                        \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                      \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                         \
        VSTORE(M0)                                                                                        \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));     \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                    \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                       \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);               \
        DATA_TYPE res1 = a4.s##i;                                                                         \
        VSTORE(4)                                                                                         \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                             \
    ({                                                                                                       \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                          \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                            \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                          \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                              \
        VSTORE(4)                                                                                            \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));       \
        VSTORE(2)                                                                                            \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4);   \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                             \
    ({                                                                                                       \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                          \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                            \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                          \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                     \
        VSTORE(4)                                                                                            \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));       \
        VSTORE(3)                                                                                            \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4);   \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                      \
    ({                                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));                 \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * Work-item mapping: get_global_id(0) selects the block column (K0 elements wide), get_global_id(1) selects the block row (M0 rows high)
 * and get_global_id(2) selects the batch. Block row y is written to output row (y / V0), at slot (y % V0) within that row.
 * Each block is written as K0 rows of M0 elements (column c of the block becomes output row c).
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16).
 *       Elements whose column index is >= SRC_WIDTH are written as zero (see BOUNDARY_CONDITION_X).
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Every work-item loads a full M0xK0 tile without bounds checks on the load itself.
 *       NOTE(review): this presumably relies on the caller padding the source tensor; confirm.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Padding between consecutive planes, counted in rows (elements along Y) (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // Element offset of slot (y % V0) inside an output row:
    // interleaved blocks share rows (offset by M0), otherwise each block is contiguous
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Element distance between two consecutive transposed rows (M0 elements each) of the same block
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Block coordinates handled by this work-item
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: top-left element of block (x, y)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: output row (y / V0), block column x, slot (y % V0)
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    // zinN is the extra byte offset applied to row N to skip cross-plane padding (0 when not reinterpreted as 3D)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing the row index (y * M0 + N) by HEIGHT_GEMM3D,
    // clamped to the last plane, then scaled to the byte size of the padding rows it has to skip
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (K0 elements each) from the LHS matrix,
    // zeroing the lanes that fall past SRC_WIDTH
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per block column; the column index is passed as a hex digit (0-9, A-F)
    // because TRANSPOSE_COLUMN_AND_STORE pastes it into both .s<i> and 0x<i>
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000567#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000568
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000569#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
570/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
571 * the output matrix unrolling the values.
572 *
573 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
574 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
575 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
576 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
577 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
578 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000579 * N0: 2,3,4,8,16
580 * K0: 1,2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000581 * H0: greater than 0
582 *
583 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
584 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
585 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
586 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
587 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
588 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
589 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
590 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
591 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
592 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
593 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
594 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
595 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
596 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
597 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
598 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
599 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Block size (in elements)
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between two K0xN0 blocks stored on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block.
    // Fully parenthesized so the macro always behaves as a single operand when expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: first element of the K0xN0 block at (x, y) in plane z
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address:
    //  - y      selects a group of H0 blocks (BLOCK_SIZE * H0 elements) along the output row
    //  - x % H0 selects the block inside that group
    //  - x / H0 selects the output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // VEC_DATA_TYPE(DATA_TYPE, N0) a0 = 0, a1 = 0, ... a(K0-1) = 0;
    // Rows skipped by the SRC_HEIGHT checks below keep this zero value, so they are stored as zeros.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix.
    // NOTE(review): row 0 is loaded without a SRC_HEIGHT check; presumably the NDRange never starts
    // a block past the last source row - confirm against the host-side kernel configuration.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // Every row of the block is stored (zeros for rows past SRC_HEIGHT), OUTPUT_STEP_X elements apart.
    VSTORE(N0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if K0 > 1
    VSTORE(N0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 1
#if K0 > 2
    VSTORE(N0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 2
#if K0 > 3
    VSTORE(N0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 3
#if K0 > 4
    VSTORE(N0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 4
#if K0 > 8
    VSTORE(N0)
    (a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)
    (aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
761
762#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 2,3,4,8,16
 *                                      H0: greater than 0
 * @note Source rows at or beyond SRC_HEIGHT are not read; they contribute zeros to the transposed block.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Block size (in elements)
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between two K0xN0 blocks stored on the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the transposed block.
    // Fully parenthesized so the macro always behaves as a single operand when expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: first element of the K0xN0 block at (x, y) in plane z
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address:
    //  - y      selects a group of H0 blocks (BLOCK_SIZE * H0 elements) along the output row
    //  - x % H0 selects the block inside that group
    //  - x / H0 selects the output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // VEC_DATA_TYPE(DATA_TYPE, N0) a0 = 0, a1 = 0, ... a(K0-1) = 0;
    // Rows skipped by the SRC_HEIGHT checks below keep this zero value.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. a1 needs no #if guard: K0 >= 2 is enforced by the #error below.
    // NOTE(review): row 0 is loaded without a SRC_HEIGHT check; presumably the NDRange never starts
    // a block past the last source row - confirm against the host-side kernel configuration.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    // VEC_DATA_TYPE(DATA_TYPE, K0) res0 = 0, res1 = 0, ... res(N0-1) = 0;
    // res<j> gathers lane j of every loaded row, i.e. column j of the K0xN0 block.
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0);

#if K0 == 2
    // This part computes the following transpositions:
    // 2x2 -> 2x2
    // 2x4 -> 4x2
    // 2x8 -> 8x2
    // 2x16 -> 16x2
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
#endif // N0 > 8

#elif K0 == 3 // K0 == 3
    // This part computes the following transpositions:
    // 3x2 -> 2x3
    // 3x4 -> 4x3
    // 3x8 -> 8x3
    // 3x16 -> 16x3
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8

#elif K0 == 4 // K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // Each transposed row res<j> is stored OUTPUT_STEP_X elements after the previous one.

    VSTORE(K0)
    (res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
    VSTORE(K0)
    (res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 3
    VSTORE(K0)
    (res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 3
#if N0 > 4
    VSTORE(K0)
    (res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
    VSTORE(K0)
    (res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)
    (resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
1128#endif // defined(TRANSPOSE)
1129#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
1130
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001131#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001132
// CONCAT pastes its two arguments with ##. Operands of ## are not macro-expanded, so a caller that
// needs e.g. K0 turned into a number must pass it through another macro parameter first
// (ARM_DOT_K0XN0 does this with its k0 parameter).
1133#define CONCAT(a, b) a##b
1134
// ARM_DOT<W>(a, b, c): accumulate the W-element dot product of a and b into c, one fma per lane,
// in lane order s0, s1, ... . The bodies are GNU statement expressions `({ ... })`.
// NOTE(review): c is presumably a scalar accumulator (a single vector lane at the call sites
// below) and a/b vectors of width W (scalars for W = 1) - confirm at the call sites.
// ARM_DOT1: c = a * b + c
1135#define ARM_DOT1(a, b, c) \
1136 ({ \
1137 c = fma(a, b, c); \
1138 })
// ARM_DOT2: lanes s0 then s1
1139#define ARM_DOT2(a, b, c) \
1140 ({ \
1141 c = fma(a.s0, b.s0, c); \
1142 c = fma(a.s1, b.s1, c); \
1143 })
// ARM_DOT3: ARM_DOT2 followed by lane s2
1144#define ARM_DOT3(a, b, c) \
1145 ({ \
1146 ARM_DOT2(a, b, c); \
1147 c = fma((a.s2), (b.s2), c); \
1148 })
// ARM_DOT4: ARM_DOT3 followed by lane s3
1149#define ARM_DOT4(a, b, c) \
1150 ({ \
1151 ARM_DOT3(a, b, c); \
1152 c = fma((a.s3), (b.s3), c); \
1153 })
// ARM_DOT8: ARM_DOT4 on the .lo halves, then on the .hi halves
1154#define ARM_DOT8(a, b, c) \
1155 ({ \
1156 ARM_DOT4((a.lo), (b.lo), c); \
1157 ARM_DOT4((a.hi), (b.hi), c); \
1158 })
// ARM_DOT16: ARM_DOT8 on the .lo halves, then on the .hi halves
1159#define ARM_DOT16(a, b, c) \
1160 ({ \
1161 ARM_DOT8((a.lo), (b.lo), c); \
1162 ARM_DOT8((a.hi), (b.hi), c); \
1163 })
1164
// ARM_DOT_K0XN0(k0, a, b, c): for each j in [0, N0), accumulate the k0-wide dot product of `a` with
// the vector named b##j (b0, b1, ..., bF using hex suffixes) into lane c.s<j>, via ARM_DOT<k0>.
// k0 is substituted (and therefore macro-expanded) before reaching CONCAT, so passing K0 selects
// e.g. ARM_DOT4; it must expand to 1, 2, 3, 4, 8 or 16 to name an existing ARM_DOT<W> macro.
// One definition exists per supported N0 (2, 3, 4, 8, 16); any other N0 triggers #error.
1165#if N0 == 2
1166#define ARM_DOT_K0XN0(k0, a, b, c) \
1167 ({ \
1168 CONCAT(ARM_DOT, k0) \
1169 ((a), (b##0), (c.s0)); \
1170 CONCAT(ARM_DOT, k0) \
1171 ((a), (b##1), (c.s1)); \
1172 })
1173#elif N0 == 3 // N0 == 3
1174#define ARM_DOT_K0XN0(k0, a, b, c) \
1175 ({ \
1176 CONCAT(ARM_DOT, k0) \
1177 ((a), (b##0), (c.s0)); \
1178 CONCAT(ARM_DOT, k0) \
1179 ((a), (b##1), (c.s1)); \
1180 CONCAT(ARM_DOT, k0) \
1181 ((a), (b##2), (c.s2)); \
1182 })
1183#elif N0 == 4 // N0 == 4
1184#define ARM_DOT_K0XN0(k0, a, b, c) \
1185 ({ \
1186 CONCAT(ARM_DOT, k0) \
1187 ((a), (b##0), (c.s0)); \
1188 CONCAT(ARM_DOT, k0) \
1189 ((a), (b##1), (c.s1)); \
1190 CONCAT(ARM_DOT, k0) \
1191 ((a), (b##2), (c.s2)); \
1192 CONCAT(ARM_DOT, k0) \
1193 ((a), (b##3), (c.s3)); \
1194 })
1195#elif N0 == 8 // N0 == 8
1196#define ARM_DOT_K0XN0(k0, a, b, c) \
1197 ({ \
1198 CONCAT(ARM_DOT, k0) \
1199 ((a), (b##0), (c.s0)); \
1200 CONCAT(ARM_DOT, k0) \
1201 ((a), (b##1), (c.s1)); \
1202 CONCAT(ARM_DOT, k0) \
1203 ((a), (b##2), (c.s2)); \
1204 CONCAT(ARM_DOT, k0) \
1205 ((a), (b##3), (c.s3)); \
1206 CONCAT(ARM_DOT, k0) \
1207 ((a), (b##4), (c.s4)); \
1208 CONCAT(ARM_DOT, k0) \
1209 ((a), (b##5), (c.s5)); \
1210 CONCAT(ARM_DOT, k0) \
1211 ((a), (b##6), (c.s6)); \
1212 CONCAT(ARM_DOT, k0) \
1213 ((a), (b##7), (c.s7)); \
1214 })
1215#elif N0 == 16 // N0 == 16
1216#define ARM_DOT_K0XN0(k0, a, b, c) \
1217 ({ \
1218 CONCAT(ARM_DOT, k0) \
1219 ((a), (b##0), (c.s0)); \
1220 CONCAT(ARM_DOT, k0) \
1221 ((a), (b##1), (c.s1)); \
1222 CONCAT(ARM_DOT, k0) \
1223 ((a), (b##2), (c.s2)); \
1224 CONCAT(ARM_DOT, k0) \
1225 ((a), (b##3), (c.s3)); \
1226 CONCAT(ARM_DOT, k0) \
1227 ((a), (b##4), (c.s4)); \
1228 CONCAT(ARM_DOT, k0) \
1229 ((a), (b##5), (c.s5)); \
1230 CONCAT(ARM_DOT, k0) \
1231 ((a), (b##6), (c.s6)); \
1232 CONCAT(ARM_DOT, k0) \
1233 ((a), (b##7), (c.s7)); \
1234 CONCAT(ARM_DOT, k0) \
1235 ((a), (b##8), (c.s8)); \
1236 CONCAT(ARM_DOT, k0) \
1237 ((a), (b##9), (c.s9)); \
1238 CONCAT(ARM_DOT, k0) \
1239 ((a), (b##A), (c.sA)); \
1240 CONCAT(ARM_DOT, k0) \
1241 ((a), (b##B), (c.sB)); \
1242 CONCAT(ARM_DOT, k0) \
1243 ((a), (b##C), (c.sC)); \
1244 CONCAT(ARM_DOT, k0) \
1245 ((a), (b##D), (c.sD)); \
1246 CONCAT(ARM_DOT, k0) \
1247 ((a), (b##E), (c.sE)); \
1248 CONCAT(ARM_DOT, k0) \
1249 ((a), (b##F), (c.sF)); \
1250 })
1251#else // N0 not supported
1252#error "N0 value not supported"
1253#endif // N0 conditions
1254
1255/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1256 * The LHS matrix is NOT reshaped
1257 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1258 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001259 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1260 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (i.e. -DM=52, -DN=30 and -DK=90)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001261 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
1262 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1263 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1264 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1265 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1266 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1267 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1268 * - N0 = 2, 3, 4, 8, 16
1269 * - K0 = 2, 3, 4, 8, 16
1270 * - H0 > 1
1271 *
1272 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1273 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1274 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1275 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1276 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1277 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1278 *
1279 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1280 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
1281 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem (in bytes)
1282 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
1283 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem (in bytes)
1284 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
1285 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1286 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1287 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem (in bytes)
1288 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1289 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem (in bytes)
1290 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1291 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1292 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1293 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1294 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1295 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1296 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1297 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1298 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1299 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1300 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1301 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1302 */
__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                          IMAGE_DECLARATION(rhs),
                                          IMAGE_DECLARATION(dst),
                                          uint lhs_stride_z,
                                          uint rhs_stride_z,
                                          uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                          ,
                                          uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                          ,
                                          uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                         )
{
    // Number of elements in one K0xN0 block of the reshaped RHS matrix
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS_OFFSET_X:  distance (in elements) between the starts of consecutive blocks on one reshaped row
    // RHS_STEP_X:    distance (in elements) between consecutive K0-wide vectors of the same block
    // RHS_STEP_LOOP: extra multiplier applied when advancing to the next K0 slice
    // In both layouts the per-iteration advance N0 * RHS_STEP_X * RHS_STEP_LOOP equals K0 * N0 * H0 elements.
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items past the M x N output extent have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address: each work-item starts at row y * M0
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address: H0 blocks share one reshaped row, so x selects the block within the row (x % H0)
    // and the reshaped row itself (x / H0)
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D LHS tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane index of LHS row (y * M0 + r) is that row divided by HEIGHT_GEMM3D, clamped to the last plane,
    // and then turned into a byte offset that skips the padding rows of all preceding planes
    zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    // Must clamp zin7 itself: zout7 is not declared until the output section below
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
#endif // M0 > 7

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    // Main loop: consume K in slices of K0 elements
    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, K0)
        a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        // Load values from RHS matrix: one K0-wide vector per output column, RHS_STEP_X elements apart
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 3
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 3
#if N0 > 4
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
        VEC_DATA_TYPE(DATA_TYPE, K0)
        bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    // Left-over accumulations: the last K % K0 columns are processed one element at a time
    for(; i < K; ++i)
    {
        // Load values from LHS matrix
        DATA_TYPE a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        DATA_TYPE a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        DATA_TYPE a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        DATA_TYPE a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        DATA_TYPE a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        // Load values from RHS matrix
        DATA_TYPE b0 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE b1 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
        DATA_TYPE b2 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 3
        DATA_TYPE b3 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 3
#if N0 > 4
        DATA_TYPE b4 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE b5 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE b6 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE b7 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
        DATA_TYPE b8 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE b9 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bA = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bB = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bC = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bD = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bE = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
        DATA_TYPE bF = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

        // Accumulate
        ARM_DOT_K0XN0(1, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(1, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(1, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(1, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(1, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(1, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(1, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(1, a7, b, c7);
#endif // M0 > 7

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += sizeof(DATA_TYPE);
    }

    // Top-left corner of this work-item's M0xN0 output tile
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
    zout0 *= (dst_cross_plane_pad * dst_stride_y);
#if M0 > 1
    zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
    zout1 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 1
#if M0 > 2
    zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
    zout2 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 2
#if M0 > 3
    zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
    zout3 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 3
#if M0 > 4
    zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
    zout4 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 4
#if M0 > 5
    zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
    zout5 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 5
#if M0 > 6
    zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
    zout6 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 6
#if M0 > 7
    zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
    zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
    zout7 *= (dst_cross_plane_pad * dst_stride_y);
#endif // M0 > 7

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    c0 = c0 * (DATA_TYPE)ALPHA;
#if M0 > 1
    c1 = c1 * (DATA_TYPE)ALPHA;
#endif // M0 > 1
#if M0 > 2
    c2 = c2 * (DATA_TYPE)ALPHA;
#endif // M0 > 2
#if M0 > 3
    c3 = c3 * (DATA_TYPE)ALPHA;
#endif // M0 > 3
#if M0 > 4
    c4 = c4 * (DATA_TYPE)ALPHA;
#endif // M0 > 4
#if M0 > 5
    c5 = c5 * (DATA_TYPE)ALPHA;
#endif // M0 > 5
#if M0 > 6
    c6 = c6 * (DATA_TYPE)ALPHA;
#endif // M0 > 6
#if M0 > 7
    c7 = c7 * (DATA_TYPE)ALPHA;
#endif // M0 > 7
#endif // defined(ALPHA)

    // Store output block
    VSTORE(N0)
    (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
#if M0 > 1
    VSTORE(N0)
    (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
#endif // M0 > 1
#if M0 > 2
    VSTORE(N0)
    (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
#endif // M0 > 2
#if M0 > 3
    VSTORE(N0)
    (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
#endif // M0 > 3
#if M0 > 4
    VSTORE(N0)
    (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
#endif // M0 > 4
#if M0 > 5
    VSTORE(N0)
    (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
#endif // M0 > 5
#if M0 > 6
    VSTORE(N0)
    (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
#endif // M0 > 6
#if M0 > 7
    VSTORE(N0)
    (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
#endif // M0 > 7

    // Release every block-local helper macro so later kernels can define their own
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001757
// Fused multiply-add into an accumulator: c <- fma(a, b, c).
// Kept as a GNU statement expression, so it works as a statement and its value is the updated c.
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
1762
// Only M0 = 1..8 is supported by the row helpers below
#if (M0 < 1) || (M0 > 8)
#error "M0 not supported"
#endif // (M0 < 1) || (M0 > 8)

// Per-row helpers for LD_RHS_VFMA_M0xN0.
// LD_RHS_VFMA_ROW<r>(i, a, c) broadcasts lane i of the LHS vector a<r> to N0 lanes and
// fused-multiply-adds it with the RHS vector `b` (declared by LD_RHS_VFMA_M0xN0) into c<r>.
// Helpers for rows at or beyond M0 expand to nothing, so exactly M0 accumulations are emitted.
#define LD_RHS_VFMA_ROW0(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));

#if M0 > 1
#define LD_RHS_VFMA_ROW1(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));
#else // M0 > 1
#define LD_RHS_VFMA_ROW1(i, a, c)
#endif // M0 > 1

#if M0 > 2
#define LD_RHS_VFMA_ROW2(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));
#else // M0 > 2
#define LD_RHS_VFMA_ROW2(i, a, c)
#endif // M0 > 2

#if M0 > 3
#define LD_RHS_VFMA_ROW3(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));
#else // M0 > 3
#define LD_RHS_VFMA_ROW3(i, a, c)
#endif // M0 > 3

#if M0 > 4
#define LD_RHS_VFMA_ROW4(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));
#else // M0 > 4
#define LD_RHS_VFMA_ROW4(i, a, c)
#endif // M0 > 4

#if M0 > 5
#define LD_RHS_VFMA_ROW5(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));
#else // M0 > 5
#define LD_RHS_VFMA_ROW5(i, a, c)
#endif // M0 > 5

#if M0 > 6
#define LD_RHS_VFMA_ROW6(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));
#else // M0 > 6
#define LD_RHS_VFMA_ROW6(i, a, c)
#endif // M0 > 6

#if M0 > 7
#define LD_RHS_VFMA_ROW7(i, a, c) VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7));
#else // M0 > 7
#define LD_RHS_VFMA_ROW7(i, a, c)
#endif // M0 > 7

// Loads N0 RHS values located i * RHS_STEP_X elements past rhs_offset into `b`
// (i is a single hexadecimal digit, 0-F), then accumulates lane i of each LHS row
// a0..a(M0-1), multiplied by `b`, into c0..c(M0-1).
// Expects rhs_ptr, rhs_offset and RHS_STEP_X to be in scope at the expansion site.
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                  \
    ({                                                                                                               \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                 \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        LD_RHS_VFMA_ROW0(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW1(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW2(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW3(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW4(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW5(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW6(i, a, c)                                                                                    \
        LD_RHS_VFMA_ROW7(i, a, c)                                                                                    \
    })
1850
1851/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1852 * The LHS matrix is NOT reshaped
1853 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1854 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1861 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1862 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1863 * - N0 = 2, 3, 4, 8, 16
1864 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001865 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001866 *
1867 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1868 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1869 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1870 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1871 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix
1873 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                      Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        lhs_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        lhs_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        rhs_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        rhs_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                      Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for the LHS matrix in units of elements (only if REINTERPRET_INPUT_AS_3D is defined)
 * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in units of elements (only if REINTERPRET_OUTPUT_AS_3D is defined)
1897 */
1898__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1899 IMAGE_DECLARATION(rhs),
1900 IMAGE_DECLARATION(dst),
1901 uint lhs_stride_z,
1902 uint rhs_stride_z,
1903 uint dst_stride_z
1904#if defined(REINTERPRET_INPUT_AS_3D)
1905 ,
1906 uint lhs_cross_plane_pad
1907#endif // REINTERPRET_INPUT_AS_3D
1908#if defined(REINTERPRET_OUTPUT_AS_3D)
1909 ,
1910 uint dst_cross_plane_pad
1911#endif // REINTERPRET_OUTPUT_AS_3D
1912 )
1913{
1914 // Block size
1915#define RHS_BLOCK_SIZE ((K0) * (N0))
1916
1917 // RHS offset and step X
1918#if defined(RHS_INTERLEAVE)
1919#define RHS_OFFSET_X (N0)
1920#define RHS_STEP_X ((N0) * (H0))
1921#define RHS_STEP_LOOP (1)
1922#else // defined(RHS_INTERLEAVE)
1923#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1924#define RHS_STEP_X (N0)
1925#define RHS_STEP_LOOP (H0)
1926#endif // defined(RHS_INTERLEAVE)
1927
1928 uint x = get_global_id(0);
1929 uint y = get_global_id(1);
1930 uint z = get_global_id(2);
1931
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001932#if defined(DUMMY_WORK_ITEMS)
1933 if((x * N0 >= N) || (y * M0 >= M))
1934 {
1935 return;
1936 }
1937#endif // defined(DUMMY_WORK_ITEMS)
1938
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001939 // Compute LHS matrix address
1940 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1941
1942 // Compute RHS matrix address
1943 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1944
1945#if defined(MATRIX_B_DEPTH)
1946 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1947 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1948#else // defined(MATRIX_B_DEPTH)
1949 rhs_offset += z * rhs_stride_z;
1950#endif // defined(MATRIX_B_DEPTH)
1951
1952 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1953
1954#if defined(REINTERPRET_INPUT_AS_3D)
1955 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1956 // in order to take into account the presence of possible cross plane paddings
1957 //
1958 // | |
1959 // | plane0 |
1960 // | |
1961 // |__________________|
1962 // |******************|
1963 // | cross_plane_pad |
1964 // |******************|
1965 // | |
1966 // | plane1 |
1967 // | |
1968 // |__________________|
1969
1970 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1971 zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1972 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
1973 zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
1974#if M0 > 1
1975 zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1976 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
1977 zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
1978#endif // M0 > 1
1979#if M0 > 2
1980 zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1981 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
1982 zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
1983#endif // M0 > 2
1984#if M0 > 3
1985 zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1986 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
1987 zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
1988#endif // M0 > 3
1989#if M0 > 4
1990 zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1991 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
1992 zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
1993#endif // M0 > 4
1994#if M0 > 5
1995 zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1996 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
1997 zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
1998#endif // M0 > 5
1999#if M0 > 6
2000 zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2001 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
2002 zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
2003#endif // M0 > 6
2004#if M0 > 7
2005 zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2006 zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2007 zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
2008#endif // M0 > 7
2009
2010 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2011 // multiply lhs_stride_z by DEPTH_GEMM3D
2012 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2013
2014#else // defined(REINTERPRET_INPUT_AS_3D)
2015
2016 // Add offset for batched GEMM
2017 lhs_offset += z * lhs_stride_z;
2018
2019#endif // defined(REINTERPRET_INPUT_AS_3D)
2020
2021 // Initialize the accumulators
2022 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
2023
2024 int i = 0;
2025 for(; i <= (K - K0); i += K0)
2026 {
2027 // Supported cases (M0, K0):
2028 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2029 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2030 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2031 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2032 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2033 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2034 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2035 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2036 // Load values from LHS matrix
2037 VEC_DATA_TYPE(DATA_TYPE, K0)
2038 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
2039#if M0 > 1
2040 VEC_DATA_TYPE(DATA_TYPE, K0)
2041 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
2042#endif // M0 > 1
2043#if M0 > 2
2044 VEC_DATA_TYPE(DATA_TYPE, K0)
2045 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
2046#endif // M0 > 2
2047#if M0 > 3
2048 VEC_DATA_TYPE(DATA_TYPE, K0)
2049 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
2050#endif // M0 > 3
2051#if M0 > 4
2052 VEC_DATA_TYPE(DATA_TYPE, K0)
2053 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
2054#endif // M0 > 4
2055#if M0 > 5
2056 VEC_DATA_TYPE(DATA_TYPE, K0)
2057 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
2058#endif // M0 > 5
2059#if M0 > 6
2060 VEC_DATA_TYPE(DATA_TYPE, K0)
2061 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
2062#endif // M0 > 6
2063#if M0 > 7
2064 VEC_DATA_TYPE(DATA_TYPE, K0)
2065 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
2066#endif // M0 > 7
2067
2068 LD_RHS_VFMA_M0xN0(0, a, c);
2069 LD_RHS_VFMA_M0xN0(1, a, c);
2070#if K0 > 2
2071 LD_RHS_VFMA_M0xN0(2, a, c);
2072#endif // K0 > 2
2073#if K0 > 3
2074 LD_RHS_VFMA_M0xN0(3, a, c);
2075#endif // K0 > 3
2076#if K0 > 4
2077 LD_RHS_VFMA_M0xN0(4, a, c);
2078 LD_RHS_VFMA_M0xN0(5, a, c);
2079 LD_RHS_VFMA_M0xN0(6, a, c);
2080 LD_RHS_VFMA_M0xN0(7, a, c);
2081#endif // K0 > 4
2082#if K0 > 8
2083 LD_RHS_VFMA_M0xN0(8, a, c);
2084 LD_RHS_VFMA_M0xN0(9, a, c);
2085 LD_RHS_VFMA_M0xN0(A, a, c);
2086 LD_RHS_VFMA_M0xN0(B, a, c);
2087 LD_RHS_VFMA_M0xN0(C, a, c);
2088 LD_RHS_VFMA_M0xN0(D, a, c);
2089 LD_RHS_VFMA_M0xN0(E, a, c);
2090 LD_RHS_VFMA_M0xN0(F, a, c);
2091#endif // K0 > 8
2092
2093 lhs_offset += K0 * sizeof(DATA_TYPE);
2094 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
2095 }
2096
2097 // Left-over accumulations
2098 for(; i < K; ++i)
2099 {
2100 // Load values from LHS matrix
2101 VEC_DATA_TYPE(DATA_TYPE, 2)
2102 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
2103#if M0 > 1
2104 VEC_DATA_TYPE(DATA_TYPE, 2)
2105 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
2106#endif // M0 > 1
2107#if M0 > 2
2108 VEC_DATA_TYPE(DATA_TYPE, 2)
2109 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
2110#endif // M0 > 2
2111#if M0 > 3
2112 VEC_DATA_TYPE(DATA_TYPE, 2)
2113 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
2114#endif // M0 > 3
2115#if M0 > 4
2116 VEC_DATA_TYPE(DATA_TYPE, 2)
2117 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
2118#endif // M0 > 4
2119#if M0 > 5
2120 VEC_DATA_TYPE(DATA_TYPE, 2)
2121 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
2122#endif // M0 > 5
2123#if M0 > 6
2124 VEC_DATA_TYPE(DATA_TYPE, 2)
2125 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
2126#endif // M0 > 6
2127#if M0 > 7
2128 VEC_DATA_TYPE(DATA_TYPE, 2)
2129 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin));
2130#endif // M0 > 7
2131
2132 LD_RHS_VFMA_M0xN0(0, a, c);
2133
2134 lhs_offset += sizeof(DATA_TYPE);
2135 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
2136 }
2137
2138 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2139
2140 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2141
2142#if defined(REINTERPRET_OUTPUT_AS_3D)
2143 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2144 // in order to take into account the presence of possible cross plane paddings
2145 //
2146 // | |
2147 // | plane0 |
2148 // | |
2149 // |__________________|
2150 // |******************|
2151 // | cross_plane_pad |
2152 // |******************|
2153 // | |
2154 // | plane1 |
2155 // | |
2156 // |__________________|
2157
2158 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2159 zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2160 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
2161 zout0 *= (dst_cross_plane_pad * dst_stride_y);
2162#if M0 > 1
2163 zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2164 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
2165 zout1 *= (dst_cross_plane_pad * dst_stride_y);
2166#endif // M0 > 1
2167#if M0 > 2
2168 zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2169 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
2170 zout2 *= (dst_cross_plane_pad * dst_stride_y);
2171#endif // M0 > 2
2172#if M0 > 3
2173 zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2174 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
2175 zout3 *= (dst_cross_plane_pad * dst_stride_y);
2176#endif // M0 > 3
2177#if M0 > 4
2178 zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2179 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
2180 zout4 *= (dst_cross_plane_pad * dst_stride_y);
2181#endif // M0 > 4
2182#if M0 > 5
2183 zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2184 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
2185 zout5 *= (dst_cross_plane_pad * dst_stride_y);
2186#endif // M0 > 5
2187#if M0 > 6
2188 zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2189 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
2190 zout6 *= (dst_cross_plane_pad * dst_stride_y);
2191#endif // M0 > 6
2192#if M0 > 7
2193 zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2194 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2195 zout7 *= (dst_cross_plane_pad * dst_stride_y);
2196#endif // M0 > 7
2197
2198 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2199 // multiply dst_stride_z by DEPTH_GEMM3D
2200 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2201
2202#else // defined(REINTERPRET_OUTPUT_AS_3D)
2203
2204 // Add offset for batched GEMM
2205 dst_addr += z * dst_stride_z;
2206
2207#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2208
2209 // Multiply by the weight of matrix-matrix product and store the result
2210#if defined(ALPHA)
2211 c0 = c0 * (DATA_TYPE)ALPHA;
2212#if M0 > 1
2213 c1 = c1 * (DATA_TYPE)ALPHA;
2214#endif // M0 > 1
2215#if M0 > 2
2216 c2 = c2 * (DATA_TYPE)ALPHA;
2217#endif // M0 > 2
2218#if M0 > 3
2219 c3 = c3 * (DATA_TYPE)ALPHA;
2220#endif // M0 > 3
2221#if M0 > 4
2222 c4 = c4 * (DATA_TYPE)ALPHA;
2223#endif // M0 > 4
2224#if M0 > 5
2225 c5 = c5 * (DATA_TYPE)ALPHA;
2226#endif // M0 > 5
2227#if M0 > 6
2228 c6 = c6 * (DATA_TYPE)ALPHA;
2229#endif // M0 > 5
2230#if M0 > 7
2231 c7 = c7 * (DATA_TYPE)ALPHA;
2232#endif // M0 > 7
2233#endif // defined(ALPHA)
2234
2235 // Store output block
2236 VSTORE(N0)
2237 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
2238#if M0 > 1
2239 VSTORE(N0)
2240 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
2241#endif // M0 > 1
2242#if M0 > 2
2243 VSTORE(N0)
2244 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
2245#endif // M0 > 2
2246#if M0 > 3
2247 VSTORE(N0)
2248 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
2249#endif // M0 > 3
2250#if M0 > 4
2251 VSTORE(N0)
2252 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
2253#endif // M0 > 4
2254#if M0 > 5
2255 VSTORE(N0)
2256 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
2257#endif // M0 > 5
2258#if M0 > 6
2259 VSTORE(N0)
2260 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
2261#endif // M0 > 6
2262#if M0 > 7
2263 VSTORE(N0)
2264 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
2265#endif // M0 > 7
2266
2267#undef RHS_BLOCK_SIZE
2268#undef RHS_OFFSET_X
2269#undef RHS_STEP_X
2270}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002271#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002272
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002273#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002274
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002275#if K0 == 2
2276#define ARM_DOT_K0(a, b, c) \
2277 ({ \
2278 c = fma(a.s0, b.s0, c); \
2279 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002280 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002281#elif K0 == 3 // K0 == 3
2282#define ARM_DOT_K0(a, b, c) \
2283 ({ \
2284 c = fma(a.s0, b.s0, c); \
2285 c = fma(a.s1, b.s1, c); \
2286 c = fma(a.s2, b.s2, c); \
2287 })
2288#elif K0 == 4 // K0 == 4
2289#define ARM_DOT_K0(a, b, c) \
2290 ({ \
2291 c = fma(a.s0, b.s0, c); \
2292 c = fma(a.s1, b.s1, c); \
2293 c = fma(a.s2, b.s2, c); \
2294 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002295 })
2296#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002297#define ARM_DOT_K0(a, b, c) \
2298 ({ \
2299 c = fma(a.s0, b.s0, c); \
2300 c = fma(a.s1, b.s1, c); \
2301 c = fma(a.s2, b.s2, c); \
2302 c = fma(a.s3, b.s3, c); \
2303 c = fma(a.s4, b.s4, c); \
2304 c = fma(a.s5, b.s5, c); \
2305 c = fma(a.s6, b.s6, c); \
2306 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002307 })
2308#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002309#define ARM_DOT_K0(a, b, c) \
2310 ({ \
2311 c = fma(a.s0, b.s0, c); \
2312 c = fma(a.s1, b.s1, c); \
2313 c = fma(a.s2, b.s2, c); \
2314 c = fma(a.s3, b.s3, c); \
2315 c = fma(a.s4, b.s4, c); \
2316 c = fma(a.s5, b.s5, c); \
2317 c = fma(a.s6, b.s6, c); \
2318 c = fma(a.s7, b.s7, c); \
2319 c = fma(a.s8, b.s8, c); \
2320 c = fma(a.s9, b.s9, c); \
2321 c = fma(a.sA, b.sA, c); \
2322 c = fma(a.sB, b.sB, c); \
2323 c = fma(a.sC, b.sC, c); \
2324 c = fma(a.sD, b.sD, c); \
2325 c = fma(a.sE, b.sE, c); \
2326 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002327 })
2328#else // K0 not supported
2329#error "K0 value not supported"
2330#endif // K0 conditions
2331
2332#if N0 == 2
2333#define ARM_DOT_K0XN0(a, b, c) \
2334 ({ \
2335 ARM_DOT_K0((a), (b##0), (c.s0)); \
2336 ARM_DOT_K0((a), (b##1), (c.s1)); \
2337 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002338#elif N0 == 3 // N0 == 3
2339#define ARM_DOT_K0XN0(a, b, c) \
2340 ({ \
2341 ARM_DOT_K0((a), (b##0), (c.s0)); \
2342 ARM_DOT_K0((a), (b##1), (c.s1)); \
2343 ARM_DOT_K0((a), (b##2), (c.s2)); \
2344 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002345#elif N0 == 4 // N0 == 4
2346#define ARM_DOT_K0XN0(a, b, c) \
2347 ({ \
2348 ARM_DOT_K0((a), (b##0), (c.s0)); \
2349 ARM_DOT_K0((a), (b##1), (c.s1)); \
2350 ARM_DOT_K0((a), (b##2), (c.s2)); \
2351 ARM_DOT_K0((a), (b##3), (c.s3)); \
2352 })
2353#elif N0 == 8 // N0 == 8
2354#define ARM_DOT_K0XN0(a, b, c) \
2355 ({ \
2356 ARM_DOT_K0((a), (b##0), (c.s0)); \
2357 ARM_DOT_K0((a), (b##1), (c.s1)); \
2358 ARM_DOT_K0((a), (b##2), (c.s2)); \
2359 ARM_DOT_K0((a), (b##3), (c.s3)); \
2360 ARM_DOT_K0((a), (b##4), (c.s4)); \
2361 ARM_DOT_K0((a), (b##5), (c.s5)); \
2362 ARM_DOT_K0((a), (b##6), (c.s6)); \
2363 ARM_DOT_K0((a), (b##7), (c.s7)); \
2364 })
2365#elif N0 == 16 // N0 == 16
2366#define ARM_DOT_K0XN0(a, b, c) \
2367 ({ \
2368 ARM_DOT_K0((a), (b##0), (c.s0)); \
2369 ARM_DOT_K0((a), (b##1), (c.s1)); \
2370 ARM_DOT_K0((a), (b##2), (c.s2)); \
2371 ARM_DOT_K0((a), (b##3), (c.s3)); \
2372 ARM_DOT_K0((a), (b##4), (c.s4)); \
2373 ARM_DOT_K0((a), (b##5), (c.s5)); \
2374 ARM_DOT_K0((a), (b##6), (c.s6)); \
2375 ARM_DOT_K0((a), (b##7), (c.s7)); \
2376 ARM_DOT_K0((a), (b##8), (c.s8)); \
2377 ARM_DOT_K0((a), (b##9), (c.s9)); \
2378 ARM_DOT_K0((a), (b##A), (c.sA)); \
2379 ARM_DOT_K0((a), (b##B), (c.sB)); \
2380 ARM_DOT_K0((a), (b##C), (c.sC)); \
2381 ARM_DOT_K0((a), (b##D), (c.sD)); \
2382 ARM_DOT_K0((a), (b##E), (c.sE)); \
2383 ARM_DOT_K0((a), (b##F), (c.sF)); \
2384 })
2385#else // N0 not supported
2386#error "N0 value not supported"
2387#endif // N0 conditions
2388
2389/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2390 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
2391 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
2392 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002393 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2394 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002395 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
2396 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
2397 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
2398 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
2399 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2400 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002401 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002402 * - N0 = 2, 3, 4, 8, 16
2403 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002404 *
2405 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2406 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2407 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2408 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2409 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2410 *
2411 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2412 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2413 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2414 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2415 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2416 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002417 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002418 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2419 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2420 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2421 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2422 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002423 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002424 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2425 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2426 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2427 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2428 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002429 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002430 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2431 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2432 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2433 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2434 */
2435__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
2436 IMAGE_DECLARATION(rhs),
2437 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002438 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002439 uint lhs_stride_z,
2440 uint rhs_stride_z,
2441 uint dst_stride_z
2442#if defined(REINTERPRET_OUTPUT_AS_3D)
2443 ,
2444 uint dst_cross_plane_pad
2445#endif // REINTERPRET_OUTPUT_AS_3D
2446 )
2447{
2448 // Block size
2449#define LHS_BLOCK_SIZE ((K0) * (M0))
2450
2451#if defined(LHS_INTERLEAVE)
2452#define LHS_OFFSET_X (K0)
2453#define LHS_STEP_X ((K0) * (V0))
2454#define LHS_STEP_LOOP (1)
2455#else // defined(INTERLEAVE)
2456#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2457#define LHS_STEP_X (K0)
2458#define LHS_STEP_LOOP (V0)
2459#endif // defined(INTERLEAVE)
2460
2461 // Block size
2462#define RHS_BLOCK_SIZE ((K0) * (N0))
2463
2464 // RHS offset and step X
2465#if defined(RHS_INTERLEAVE)
2466#define RHS_OFFSET_X (K0)
2467#define RHS_STEP_X ((K0) * (H0))
2468#define RHS_STEP_LOOP (1)
2469#else // defined(RHS_INTERLEAVE)
2470#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2471#define RHS_STEP_X (K0)
2472#define RHS_STEP_LOOP (H0)
2473#endif // defined(RHS_INTERLEAVE)
2474
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00002475#if defined(DUMMY_WORK_ITEMS)
2476 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
2477 {
2478 return;
2479 }
2480#endif // defined(DUMMY_WORK_ITEMS)
2481
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002482 // Compute LHS matrix address
2483 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2484 (get_global_id(2) * lhs_stride_z);
2485
2486 // Compute RHS matrix address
2487 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
2488
2489#if defined(MATRIX_B_DEPTH)
2490 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2491 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
2492#else // defined(MATRIX_B_DEPTH)
2493 rhs_addr += get_global_id(2) * rhs_stride_z;
2494#endif // defined(MATRIX_B_DEPTH)
2495
2496 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002497 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002498
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002499 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002500 {
2501 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002502 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2503 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2504 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2505 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2506 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2507 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2508 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2509 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002510 // Load values from LHS matrix
2511 VEC_DATA_TYPE(DATA_TYPE, K0)
2512 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE)));
2513#if M0 > 1
2514 VEC_DATA_TYPE(DATA_TYPE, K0)
2515 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE)));
2516#endif // M0 > 1
2517#if M0 > 2
2518 VEC_DATA_TYPE(DATA_TYPE, K0)
2519 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE)));
2520#endif // M0 > 2
2521#if M0 > 3
2522 VEC_DATA_TYPE(DATA_TYPE, K0)
2523 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE)));
2524#endif // M0 > 3
2525#if M0 > 4
2526 VEC_DATA_TYPE(DATA_TYPE, K0)
2527 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE)));
2528#endif // M0 > 4
2529#if M0 > 5
2530 VEC_DATA_TYPE(DATA_TYPE, K0)
2531 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE)));
2532#endif // M0 > 5
2533#if M0 > 6
2534 VEC_DATA_TYPE(DATA_TYPE, K0)
2535 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE)));
2536#endif // M0 > 6
2537#if M0 > 7
2538 VEC_DATA_TYPE(DATA_TYPE, K0)
2539 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE)));
2540#endif // M0 > 7
2541
2542 // Load values from RHS matrix
2543 VEC_DATA_TYPE(DATA_TYPE, K0)
2544 b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
2545 VEC_DATA_TYPE(DATA_TYPE, K0)
2546 b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
2547#if N0 > 2
2548 VEC_DATA_TYPE(DATA_TYPE, K0)
2549 b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002550#endif // N0 > 2
2551#if N0 > 3
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002552 VEC_DATA_TYPE(DATA_TYPE, K0)
2553 b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002554#endif // N0 > 3
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002555#if N0 > 4
2556 VEC_DATA_TYPE(DATA_TYPE, K0)
2557 b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
2558 VEC_DATA_TYPE(DATA_TYPE, K0)
2559 b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
2560 VEC_DATA_TYPE(DATA_TYPE, K0)
2561 b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
2562 VEC_DATA_TYPE(DATA_TYPE, K0)
2563 b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
2564#endif // N0 > 4
2565#if N0 > 8
2566 VEC_DATA_TYPE(DATA_TYPE, K0)
2567 b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
2568 VEC_DATA_TYPE(DATA_TYPE, K0)
2569 b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
2570 VEC_DATA_TYPE(DATA_TYPE, K0)
2571 bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
2572 VEC_DATA_TYPE(DATA_TYPE, K0)
2573 bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
2574 VEC_DATA_TYPE(DATA_TYPE, K0)
2575 bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
2576 VEC_DATA_TYPE(DATA_TYPE, K0)
2577 bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
2578 VEC_DATA_TYPE(DATA_TYPE, K0)
2579 bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
2580 VEC_DATA_TYPE(DATA_TYPE, K0)
2581 bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
2582#endif // N0 > 8
2583
2584 // Accumulate
2585 ARM_DOT_K0XN0(a0, b, c0);
2586#if M0 > 1
2587 ARM_DOT_K0XN0(a1, b, c1);
2588#endif // M0 > 1
2589#if M0 > 2
2590 ARM_DOT_K0XN0(a2, b, c2);
2591#endif // M0 > 2
2592#if M0 > 3
2593 ARM_DOT_K0XN0(a3, b, c3);
2594#endif // M0 > 3
2595#if M0 > 4
2596 ARM_DOT_K0XN0(a4, b, c4);
2597#endif // M0 > 4
2598#if M0 > 5
2599 ARM_DOT_K0XN0(a5, b, c5);
2600#endif // M0 > 5
2601#if M0 > 6
2602 ARM_DOT_K0XN0(a6, b, c6);
2603#endif // M0 > 6
2604#if M0 > 7
2605 ARM_DOT_K0XN0(a7, b, c7);
2606#endif // M0 > 7
2607
2608 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2609 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
2610 }
2611
2612 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2613
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002614 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002615
2616#if defined(REINTERPRET_OUTPUT_AS_3D)
2617 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2618 // in order to take into account the presence of possible cross plane paddings
2619 //
2620 // | |
2621 // | plane0 |
2622 // | |
2623 // |__________________|
2624 // |******************|
2625 // | cross_plane_pad |
2626 // |******************|
2627 // | |
2628 // | plane1 |
2629 // | |
2630 // |__________________|
2631
2632 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2633 zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2634 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002635 zout0 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002636#if M0 > 1
2637 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2638 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002639 zout1 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002640#endif // M0 > 1
2641#if M0 > 2
2642 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2643 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002644 zout2 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002645#endif // M0 > 2
2646#if M0 > 3
2647 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2648 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002649 zout3 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002650#endif // M0 > 3
2651#if M0 > 4
2652 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2653 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002654 zout4 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002655#endif // M0 > 4
2656#if M0 > 5
2657 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2658 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002659 zout5 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002660#endif // M0 > 5
2661#if M0 > 6
2662 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2663 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002664 zout6 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002665#endif // M0 > 6
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002666#if M0 > 7
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002667 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2668 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002669 zout7 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002670#endif // M0 > 7
2671
2672 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2673 // multiply dst_stride_z by DEPTH_GEMM3D
2674 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2675
2676#else // defined(REINTERPRET_OUTPUT_AS_3D)
2677
2678 // Add offset for batched GEMM
2679 dst_addr += get_global_id(2) * dst_stride_z;
2680
2681#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2682
2683 // Multiply by the weight of matrix-matrix product and store the result
2684#if defined(ALPHA)
2685 c0 = c0 * (DATA_TYPE)ALPHA;
2686#if M0 > 1
2687 c1 = c1 * (DATA_TYPE)ALPHA;
2688#endif // M0 > 1
2689#if M0 > 2
2690 c2 = c2 * (DATA_TYPE)ALPHA;
2691#endif // M0 > 2
2692#if M0 > 3
2693 c3 = c3 * (DATA_TYPE)ALPHA;
2694#endif // M0 > 3
2695#if M0 > 4
2696 c4 = c4 * (DATA_TYPE)ALPHA;
2697#endif // M0 > 4
2698#if M0 > 5
2699 c5 = c5 * (DATA_TYPE)ALPHA;
2700#endif // M0 > 5
2701#if M0 > 6
2702 c6 = c6 * (DATA_TYPE)ALPHA;
2703#endif // M0 > 5
2704#if M0 > 7
2705 c7 = c7 * (DATA_TYPE)ALPHA;
2706#endif // M0 > 7
2707#endif // defined(ALPHA)
2708
2709 // Store output block
2710 VSTORE(N0)
2711 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
2712#if M0 > 1
2713 VSTORE(N0)
2714 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
2715#endif // M0 > 1
2716#if M0 > 2
2717 VSTORE(N0)
2718 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
2719#endif // M0 > 2
2720#if M0 > 3
2721 VSTORE(N0)
2722 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
2723#endif // M0 > 3
2724#if M0 > 4
2725 VSTORE(N0)
2726 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
2727#endif // M0 > 4
2728#if M0 > 5
2729 VSTORE(N0)
2730 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
2731#endif // M0 > 5
2732#if M0 > 6
2733 VSTORE(N0)
2734 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
2735#endif // M0 > 6
2736#if M0 > 7
2737 VSTORE(N0)
2738 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
2739#endif // M0 > 7
2740
2741#undef LHS_BLOCK_SIZE
2742#undef LHS_OFFSET_X
2743#undef LHS_STEP_X
2744#undef RHS_BLOCK_SIZE
2745#undef RHS_OFFSET_X
2746#undef RHS_STEP_X
2747}
#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
2749
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// The 1xW transposition only moves data around, so the element type is selected from the element size alone:
// every supported data type is handled through the unsigned integer type with the same number of bytes.
// If ELEMENT_SIZE is not defined, the preprocessor evaluates it as 0 below and the #error is raised.
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is not 1, 2 or 4
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Work-item (x, y, z) loads TRANSPOSE_W consecutive elements at its source position and stores them with X and Y swapped:
 * the block goes to destination row (x / MULT_TRANSPOSE1XW_WIDTH), at element offset
 * (y * MULT_TRANSPOSE1XW_WIDTH + x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W, in destination batch z.
 *
 * @note The size in bytes of one element must be passed at compile time using -DELEMENT_SIZE (e.g. -DELEMENT_SIZE=4). Supported values: 1, 2, 4
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4).
 *       It is used as an OpenCL vector width (vector type and vload/vstore size), so it must be 2, 3, 4, 8 or 16
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for Matrix B - source
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix B transposed - destination. X and Y are swapped:
    // - (x / MULT_TRANSPOSE1XW_WIDTH) selects the destination row
    // - source row y owns a run of TRANSPOSE_W * MULT_TRANSPOSE1XW_WIDTH elements in that row
    // - (x % MULT_TRANSPOSE1XW_WIDTH) selects the TRANSPOSE_W-wide slot inside that run
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + y * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH + (x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y +
                             (x % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002808
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 * will be simply unrolled.
 *
 * Work-item (x, y, z) reads the 4 rows of one 4x4 source block and writes them (transposed, or as-is with UNROLL_BLOCK)
 * to destination row (y / MULT_INTERLEAVE4X4_HEIGHT). The MULT_INTERLEAVE4X4_HEIGHT blocks sharing that row are interleaved:
 * block (y % MULT_INTERLEAVE4X4_HEIGHT) starts 4 * (y % MULT_INTERLEAVE4X4_HEIGHT) elements into the
 * 16 * MULT_INTERLEAVE4X4_HEIGHT element span owned by column x, and its four 4-element vectors are
 * 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for source tensor
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Compute address for Matrix A interleaved - destination.
    // (y / MULT_INTERLEAVE4X4_HEIGHT) selects the destination row, x selects the 16 * MULT_INTERLEAVE4X4_HEIGHT element span
    // and (y % MULT_INTERLEAVE4X4_HEIGHT) selects the 4-element interleaving slot inside it
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

#if defined(REINTERPRET_INPUT_AS_3D)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly because the OpenCL built-in min() is only
    // defined as min(gentype, gentype) or min(gentype, sgentype), not min(scalar, vector)
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Base of this work-item's interleaving slot; consecutive 4-element vectors are 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart
    __global DATA_TYPE *dst_block_ptr = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

#if defined(UNROLL_BLOCK)
    // Store the 4 rows of the block as they are
    vstore4(a0, 0, dst_block_ptr + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a1, 0, dst_block_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a2, 0, dst_block_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a3, 0, dst_block_ptr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#else  // defined(UNROLL_BLOCK)
    // Store the block transposed: the i-th stored vector is the i-th column of the 4x4 block
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, dst_block_ptr + 0 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, dst_block_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, dst_block_ptr + 8 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, dst_block_ptr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002936
Gian Marco36a0a462018-01-12 10:21:40 +00002937#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * Each work-item accumulates one 4x4 block of the output and stores it as 4 rows of 4 floats.
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of the reshaped matrix B (elements per src1 row) and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item's interleaving slot inside the reshaped rows of A and B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the slot offset is applied)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two reshaped K-steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one reshaped K-step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // The same 4 elements of the vector are added to every row of the output block
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float4 c0 = vload4(0, src2_addr);

    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The bound is splatted explicitly because the OpenCL built-in min() is only
    // defined as min(gentype, gentype) or min(gentype, sgentype), not min(scalar, vector)
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3135
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003136/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003137 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication.
3138 *
3139 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003140 *
Gian Marco19835e52018-01-30 13:35:54 +00003141 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
3142 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
3143 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003144 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3145 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3146 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003147 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003148 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
3149 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3150 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in units of elements along the Y dimension (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Element offset of this work-item's 4-wide slice inside the reshaped rows of A and B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cRC holds row R, column C of the 4x4 output tile
    float c00 = 0.0f;
    float c01 = 0.0f;
    float c02 = 0.0f;
    float c03 = 0.0f;
    float c10 = 0.0f;
    float c11 = 0.0f;
    float c12 = 0.0f;
    float c13 = 0.0f;
    float c20 = 0.0f;
    float c21 = 0.0f;
    float c22 = 0.0f;
    float c23 = 0.0f;
    float c30 = 0.0f;
    float c31 = 0.0f;
    float c32 = 0.0f;
    float c33 = 0.0f;

#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    // One step along K:
    // - load 4 values from matrix A (interleaved) and 4 values from matrix B (transposed);
    // - advance both pointers to the next block;
    // - accumulate the 4x4 outer product into the scalar accumulators c00..c33.
    // It uses the local pointers/accumulators of this kernel and is #undef'd right after use.
#define GEMM_F32_BIFROST_4X4_STEP()                  \
    do                                               \
    {                                                \
        float4 a0 = vload4(0, src_addr_a);           \
        float4 b0 = vload4(0, src_addr_b);           \
                                                     \
        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT; \
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;   \
                                                     \
        c00 = fma(a0.s0, b0.s0, c00);                \
        c01 = fma(a0.s0, b0.s1, c01);                \
        c02 = fma(a0.s0, b0.s2, c02);                \
        c03 = fma(a0.s0, b0.s3, c03);                \
                                                     \
        c10 = fma(a0.s1, b0.s0, c10);                \
        c11 = fma(a0.s1, b0.s1, c11);                \
        c12 = fma(a0.s1, b0.s2, c12);                \
        c13 = fma(a0.s1, b0.s3, c13);                \
                                                     \
        c20 = fma(a0.s2, b0.s0, c20);                \
        c21 = fma(a0.s2, b0.s1, c21);                \
        c22 = fma(a0.s2, b0.s2, c22);                \
        c23 = fma(a0.s2, b0.s3, c23);                \
                                                     \
        c30 = fma(a0.s3, b0.s0, c30);                \
        c31 = fma(a0.s3, b0.s1, c31);                \
        c32 = fma(a0.s3, b0.s2, c32);                \
        c33 = fma(a0.s3, b0.s3, c33);                \
    } while(0)

    int i = 0;

    // Main loop: 4 steps along K per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        GEMM_F32_BIFROST_4X4_STEP();
        GEMM_F32_BIFROST_4X4_STEP();
        GEMM_F32_BIFROST_4X4_STEP();
        GEMM_F32_BIFROST_4X4_STEP();
    }

    // Left-over steps along K
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        GEMM_F32_BIFROST_4X4_STEP();
    }

#undef GEMM_F32_BIFROST_4X4_STEP

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;
    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;
    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;
    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ADD_VEC_C)
    // The same 4 vector values (one per output column) are added to every row of the tile
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float4          c0        = vload4(0, src2_addr);

    c00 += c0.s0;
    c01 += c0.s1;
    c02 += c0.s2;
    c03 += c0.s3;
    c10 += c0.s0;
    c11 += c0.s1;
    c12 += c0.s2;
    c13 += c0.s3;
    c20 += c0.s0;
    c21 += c0.s1;
    c22 += c0.s2;
    c23 += c0.s3;
    c30 += c0.s0;
    c31 += c0.s1;
    c32 += c0.s2;
    c33 += c0.s3;
#endif /* defined(ADD_VEC_C) */

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in units of elements along the Y dimension (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Element offset of this work-item's slice inside the reshaped rows of A (4 wide) and B (8 wide)
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the per-work-item offset is applied)
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: cR holds row R of the 4x8 output tile
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // Main loop: 2 steps along K per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Left-over step along K
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c00 += (half8)a0.s0 * b0;
        c10 += (half8)a0.s1 * b0;
        c20 += (half8)a0.s2 * b0;
        c30 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8 c0 = vload8(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    // The same 8 vector values (one per output column) are added to every row of the tile
    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32-bit floating point variable.
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped (M)
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in units of elements along the Y dimension (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // F16 in / F16 out with F32 accumulation: every loaded half is widened to float, the 4x8 output tile is
    // accumulated, scaled and offset in float, and it is only narrowed back to half when it is stored.

    // Raw work-item ids, kept as size_t so that every derived expression has the same type as a direct
    // get_global_id() call
    const size_t gid0 = get_global_id(0);
    const size_t gid1 = get_global_id(1);

    // Reshaped row of B, reshaped row of A and batch handled by this work-item. Groups of
    // MULT_TRANSPOSE1XW_WIDTH (resp. MULT_INTERLEAVE4X4_HEIGHT) consecutive work-items share one reshaped row.
    const int b_row = gid0 / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row = gid1 / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch = get_global_id(2);

    // Element offsets of this work-item's 4-wide slice of A and 8-wide slice of B inside the shared rows
    const int a_slice = (gid1 % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_slice = (gid0 % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the reshaped rows of A and B
    const int a_bytes = batch * src0_stride_z + a_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int       b_bytes = b_row * src1_stride_y + src1_offset_first_element_in_bytes;
#if defined(MATRIX_B_DEPTH)
    // B has only MATRIX_B_DEPTH channels while A may carry more batches: wrap the batch index instead of sliding B
    b_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Row pointers. The end of the B row is taken from the row start, before the slice offset is applied.
    __global half *a_ptr       = (__global half *)(src0_ptr + a_bytes) + a_slice;
    __global half *b_ptr       = (__global half *)(src1_ptr + b_bytes);
    __global half *const b_end = b_ptr + COLS_B;
    b_ptr += b_slice;

    // One F32 accumulator per output row of the 4x8 tile
    float8 acc_r0 = 0.0f;
    float8 acc_r1 = 0.0f;
    float8 acc_r2 = 0.0f;
    float8 acc_r3 = 0.0f;

    // Main loop: two K steps per iteration, as long as a full pair of B steps is still available
    __global half *const b_last_pair = b_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH);
    while(b_ptr <= b_last_pair)
    {
        // First K step of the pair
        float4 a_vals = convert_float4(vload4(0, a_ptr));
        float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc_r0 += (float8)a_vals.s0 * b_vals;
        acc_r1 += (float8)a_vals.s1 * b_vals;
        acc_r2 += (float8)a_vals.s2 * b_vals;
        acc_r3 += (float8)a_vals.s3 * b_vals;

        // Second K step of the pair
        a_vals = convert_float4(vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b_vals = convert_float8(vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH));

        acc_r0 += (float8)a_vals.s0 * b_vals;
        acc_r1 += (float8)a_vals.s1 * b_vals;
        acc_r2 += (float8)a_vals.s2 * b_vals;
        acc_r3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Remaining K steps, one at a time
    while(b_ptr < b_end)
    {
        const float4 a_vals = convert_float4(vload4(0, a_ptr));
        const float8 b_vals = convert_float8(vload8(0, b_ptr));

        acc_r0 += (float8)a_vals.s0 * b_vals;
        acc_r1 += (float8)a_vals.s1 * b_vals;
        acc_r2 += (float8)a_vals.s2 * b_vals;
        acc_r3 += (float8)a_vals.s3 * b_vals;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#if defined(ALPHA)
    // Scale the product by ALPHA (still in F32)
    acc_r0 *= (float8)ALPHA;
    acc_r1 *= (float8)ALPHA;
    acc_r2 *= (float8)ALPHA;
    acc_r3 *= (float8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    // The same 8 elements of src2 (selected by the column block) are added to all four output rows
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + gid0 * src2_step_x);
    const float8 vec_c = convert_float8(vload8(0, vec_c_ptr));
    // clang-format on
    // *INDENT-ON*

    acc_r0 += vec_c;
    acc_r1 += vec_c;
    acc_r2 += vec_c;
    acc_r3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Byte address of this work-item's output tile
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *out_base = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile is written into a 3D tensor whose planes are HEIGHT_GEMM3D rows tall and separated by
    // cross_plane_pad rows of padding. Compute the plane of each of the 4 output rows (clamped to the last plane)
    // and turn it into a byte offset covering the paddings crossed so far.
    uint4 plane_off = ((uint4)(0, 1, 2, 3) + (uint4)(gid1 * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_off = min(DEPTH_GEMM3D - 1, plane_off);
    plane_off *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence the DEPTH_GEMM3D factor on dst_stride_z
    out_base += batch * dst_stride_z * DEPTH_GEMM3D;

    // Narrow to F16 and store the 4x8 tile
    vstore8(convert_half8(acc_r0), 0, (__global half *)(out_base + 0 * dst_stride_y + plane_off.s0));
    vstore8(convert_half8(acc_r1), 0, (__global half *)(out_base + 1 * dst_stride_y + plane_off.s1));
    vstore8(convert_half8(acc_r2), 0, (__global half *)(out_base + 2 * dst_stride_y + plane_off.s2));
    vstore8(convert_half8(acc_r3), 0, (__global half *)(out_base + 3 * dst_stride_y + plane_off.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one batch per Z slice
    out_base += batch * dst_stride_z;

    // Narrow to F16 and store the 4x8 tile
    vstore8(convert_half8(acc_r0), 0, (__global half *)(out_base + 0 * dst_stride_y));
    vstore8(convert_half8(acc_r1), 0, (__global half *)(out_base + 1 * dst_stride_y));
    vstore8(convert_half8(acc_r2), 0, (__global half *)(out_base + 2 * dst_stride_y));
    vstore8(convert_half8(acc_r3), 0, (__global half *)(out_base + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3886
/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A (src0) and matrix B (src1)
3888 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
3889 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003890 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3891 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003892 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3896 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
3897 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003898 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
3899 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3900 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3901 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3902 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3903 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003904 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3905 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003906 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3907 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3908 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3909 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3910 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3911 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3912 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3913 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3914 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3915 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3916 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3917 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
3923 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3924 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3925 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3926 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3927 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
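 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)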
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003929 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Each work-item computes one 4x8 block of dst. Products are accumulated, scaled by ALPHA and
    // offset by src2 directly in half precision (half8 accumulators, fma on half8).

    // x: row of the reshaped (transposed) B, y: row of the reshaped (interleaved) A, z: batch index.
    // MULT_TRANSPOSE1XW_WIDTH (resp. MULT_INTERLEAVE4X4_HEIGHT) consecutive work-items share one reshaped row.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Element offsets of this work-item's 4-wide slice of A and 8-wide slice of B within the shared rows
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the reshaped rows of matrix A and matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // The K loops below are bounded by the step count COLS_MTX_B, so no end-of-row pointer for B is needed
    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one per output row of the 4x8 block)
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

    // Number of K steps: each step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the reshaped B row
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    // Main loop, manually unrolled by four K steps.
    // Keep the statement order unchanged: in F16 any reordering of the fma chain changes the rounding.
    int i = 0;
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1, two consecutive K steps of A are contiguous, so one vload8
        // supplies A for two steps: s0-s3 for the first step, s4-s7 for the second.

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // K steps of A are 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart, so A is loaded 4 elements at a time

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over K steps (fewer than four), one at a time
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    // The same 8 elements of src2 (selected by the column block) are added to all four output rows
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8 c0 = vload8(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (zout becomes a byte offset per row)
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Georgios Pinitas84225582018-05-14 12:00:05 +01004167
// Undefine local defines
#undef COLS_MTX_B

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// Kernels on non-reshaped matrices: all three sizes must be provided at compile time.
// Every operand is a defined() test. A bare (NUM_ELEMS_PROCESSED_PER_THREAD_Y) would test the macro's value,
// not whether it was passed.
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) when neither matrix has been reshaped.
4179 *
4180 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004181 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004182 * @note This OpenCL kernel works with floating point data types (F16/F32)
4183 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
4184 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004185 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004186 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
4187 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004188 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004189 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4190 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004191 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4192 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4193 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4194 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4195 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004196 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4197 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004198 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004199 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4200 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4201 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4202 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4203 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004204 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004205 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4206 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4207 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4208 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4209 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
4219 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004220 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4221 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4222 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004223 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4224 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004225 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004226__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
4227 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004228#if defined(ADD_VEC_C)
4229 VECTOR_DECLARATION(src2),
4230#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004231 IMAGE_DECLARATION(dst),
4232 uint src0_stride_z,
4233 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004234 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004235#if defined(REINTERPRET_INPUT_AS_3D)
4236 ,
4237 uint src_cross_plane_pad
4238#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004239#if defined(REINTERPRET_OUTPUT_AS_3D)
4240 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004241 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004242#endif // REINTERPRET_OUTPUT_AS_3D
4243 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004244{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004245 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004246
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004247 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004248 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004249
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004250 // Update address for the matrix A
4251 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004252
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004253 // Update address for the matrix B
4254 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004255
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004256#if defined(REINTERPRET_INPUT_AS_3D)
4257 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4258 // in order to take into account the presence of possible cross plane paddings
4259 //
4260 // | |
4261 // | plane0 |
4262 // | |
4263 // |__________________|
4264 // |******************|
4265 // | cross_plane_pad |
4266 // |******************|
4267 // | |
4268 // | plane1 |
4269 // | |
4270 // |__________________|
4271
4272 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4273 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4274 zin = min(DEPTH_GEMM3D - 1, zin);
4275
4276 // Add offset due to the cross plane paddings
4277 zin *= (src_cross_plane_pad * src0_stride_y);
4278
4279 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4280 // multiply src0_stride_z by DEPTH_GEMM3D
4281 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4282
4283#else // defined(REINTERPRET_INPUT_AS_3D)
4284
Gian Marcoae2af742018-02-15 12:35:44 +00004285 // Add offset for batched GEMM
4286 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004287
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004288#endif // defined(REINTERPRET_INPUT_AS_3D)
4289
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004290#if defined(MATRIX_B_DEPTH)
4291 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4292 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4293#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004294 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004295#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004296
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004297 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
4298
4299 VECTOR_TYPE acc0 = 0.0f;
4300#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4301 VECTOR_TYPE acc1 = 0.0f;
4302#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4303#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4304 VECTOR_TYPE acc2 = 0.0f;
4305#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4307 VECTOR_TYPE acc3 = 0.0f;
4308#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4309
Georgios Pinitas96880cf2017-10-20 18:52:20 +01004310 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004311 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004312#if defined(REINTERPRET_INPUT_AS_3D)
4313 // Load values from matrix A
4314 VEC_DATA_TYPE(DATA_TYPE, 2)
4315 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4316#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4317 VEC_DATA_TYPE(DATA_TYPE, 2)
4318 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4320#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4321 VEC_DATA_TYPE(DATA_TYPE, 2)
4322 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4323#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4324#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4325 VEC_DATA_TYPE(DATA_TYPE, 2)
4326 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4327#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4328#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004329 // Load values from matrix A
4330 VEC_DATA_TYPE(DATA_TYPE, 2)
4331 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4332#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4333 VEC_DATA_TYPE(DATA_TYPE, 2)
4334 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4335#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4336#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4337 VEC_DATA_TYPE(DATA_TYPE, 2)
4338 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4339#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4340#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4341 VEC_DATA_TYPE(DATA_TYPE, 2)
4342 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4343#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004344#endif // defined(REINTERPRET_INPUT_AS_3D)
4345
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004346 // Load values from matrix B
4347 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
4348 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004349
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004350 // Accumulate
4351 acc0 += b0 * (VECTOR_TYPE)a0.s0;
4352 acc0 += b1 * (VECTOR_TYPE)a0.s1;
4353#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4354 acc1 += b0 * (VECTOR_TYPE)a1.s0;
4355 acc1 += b1 * (VECTOR_TYPE)a1.s1;
4356#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4357#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4358 acc2 += b0 * (VECTOR_TYPE)a2.s0;
4359 acc2 += b1 * (VECTOR_TYPE)a2.s1;
4360#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4361#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4362 acc3 += b0 * (VECTOR_TYPE)a3.s0;
4363 acc3 += b1 * (VECTOR_TYPE)a3.s1;
4364#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004365 }
4366
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004367 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004368 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004369#if defined(REINTERPRET_INPUT_AS_3D)
4370 // Load values from matrix A
4371 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4373 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4375#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4376 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4377#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4378#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4379 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4380#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4381#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004382 // Load values from matrix A
4383 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4384#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4385 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4386#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4387#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4388 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4390#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4391 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4392#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004393#endif // defined(REINTERPRET_INPUT_AS_3D)
4394
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004395 // Load values from matrix B
4396 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004397
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004398 // Accumulate
4399 acc0 += b0 * (VECTOR_TYPE)a0;
4400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4401 acc1 += b0 * (VECTOR_TYPE)a1;
4402#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4404 acc2 += b0 * (VECTOR_TYPE)a2;
4405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4407 acc3 += b0 * (VECTOR_TYPE)a3;
4408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004409 }
4410
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004411 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004412 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4413
Gian Marcoae2af742018-02-15 12:35:44 +00004414 // Compute dst address
4415 __global uchar *dst_addr = offset(&dst, 0, 0);
4416
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004417 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004418#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004419 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004420#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004421#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4422 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
4423#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4424#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4425 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
4426#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4427#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4428 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
4429#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4430
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004431#if defined(ADD_VEC_C)
4432 // *INDENT-OFF*
4433 // clang-format off
4434 __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4435 VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
4436 // clang-format on
4437 // *INDENT-ON*
4438
4439 acc0 += c0;
4440#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4441 acc1 += c0;
4442#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4443#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4444 acc2 += c0;
4445#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4446#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4447 acc3 += c0;
4448#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4449#endif /* defined(ADD_VEC_C) */
4450
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004451 int z = get_global_id(2);
4452
4453#if defined(REINTERPRET_OUTPUT_AS_3D)
4454 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004455 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004456 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004457 // | |
4458 // | plane0 |
4459 // | |
4460 // |__________________|
4461 // |******************|
4462 // | cross_plane_pad |
4463 // |******************|
4464 // | |
4465 // | plane1 |
4466 // | |
4467 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004468
4469 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4470 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4471 zout = min(DEPTH_GEMM3D - 1, zout);
4472
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004473 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004474 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004475
4476 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4477 // multiply dst_stride_z by DEPTH_GEMM3D
4478 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4479
4480 // Store output block
4481 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4482 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
4483#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4484 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4485 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
4486#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4487#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4488 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4489 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
4490#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4491#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4492 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4493 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
4494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4495
4496#else // defined(REINTERPRET_OUTPUT_AS_3D)
4497 // Add offset for batched GEMM
4498 dst_addr += z * dst_stride_z;
4499
4500 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004501 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004502 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004503#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004504 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004505 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004506#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4507#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004508 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004509 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004510#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4511#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004512 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004513 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004514#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004515#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004516}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004517#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004518
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01004519/** This OpenCL kernel computes the matrix-by-matrix multiplication between matrix A (src0) and matrix B (src1) when neither matrix has been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004520 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004521 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4522 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004523 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
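4524 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4525 * This kernel requires -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 (it always loads and accumulates four columns of matrix B per work-item) and supports -DNUM_ELEMS_PROCESSED_PER_THREAD_Y values up to 4.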
4526 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4527 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004528 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4529 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004530 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004531 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4532 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004533 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4534 * -# HEIGHT_GEMM3D: The height of the input/output in case it has to be reinterpreted as a 3D tensor.
4535 * -# DEPTH_GEMM3D: The depth of the input/output in case it has to be reinterpreted as a 3D tensor
4536 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
4537 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004538 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4539 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004540 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
4541 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4542 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
4543 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4544 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
4545 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4546 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4547 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4548 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
4549 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4550 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
4551 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004552 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
4553 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
4554 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
4555 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004556 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4557 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4558 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
4559 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4560 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
4561 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004562 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4563 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4564 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004565 * @param[in] src_cross_plane_pad (Optional) Bottom padding of each plane of the input tensor, in number of rows (only if REINTERPRET_INPUT_AS_3D is defined)
4566 * @param[in] dst_cross_plane_pad (Optional) Bottom padding of each plane of the output tensor, in number of rows (only if REINTERPRET_OUTPUT_AS_3D is defined)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004567 */
4568__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4569 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004570#if defined(ADD_VEC_C)
4571 VECTOR_DECLARATION(src2),
4572#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004573 IMAGE_DECLARATION(dst),
4574 uint src0_stride_z,
4575 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004576 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004577#if defined(REINTERPRET_INPUT_AS_3D)
4578 ,
4579 uint src_cross_plane_pad
4580#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004581#if defined(REINTERPRET_OUTPUT_AS_3D)
4582 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004583 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004584#endif // REINTERPRET_OUTPUT_AS_3D
4585 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004586{
4587 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4588
4589 // Compute starting address for matrix A and matrix B
4590 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4591
4592 // Update address for matrix A
4593 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4594
4595 // Update address for matrix B
4596 src_addr.s1 += idx * sizeof(float);
4597
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004598#if defined(REINTERPRET_INPUT_AS_3D)
4599 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4600 // in order to take into account the presence of possible cross plane paddings
4601 //
4602 // | |
4603 // | plane0 |
4604 // | |
4605 // |__________________|
4606 // |******************|
4607 // | cross_plane_pad |
4608 // |******************|
4609 // | |
4610 // | plane1 |
4611 // | |
4612 // |__________________|
4613
4614 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4615 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4616 zin = min(DEPTH_GEMM3D - 1, zin);
4617
4618 // Add offset due to the cross plane paddings
4619 zin *= (src_cross_plane_pad * src0_stride_y);
4620
4621 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4622 // multiply src0_stride_z by DEPTH_GEMM3D
4623 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4624
4625#else // defined(REINTERPRET_INPUT_AS_3D)
4626
Gian Marcoae2af742018-02-15 12:35:44 +00004627 // Add offset for batched GEMM
4628 src_addr.s0 += get_global_id(2) * src0_stride_z;
4629
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004630#endif // defined(REINTERPRET_INPUT_AS_3D)
4631
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004632#if defined(MATRIX_B_DEPTH)
4633 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4634 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4635#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004636 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004637#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004638
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004639 // Initialize accumulators
4640 float acc00 = 0.0f;
4641 float acc01 = 0.0f;
4642 float acc02 = 0.0f;
4643 float acc03 = 0.0f;
4644
4645#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4646 float acc10 = 0.0f;
4647 float acc11 = 0.0f;
4648 float acc12 = 0.0f;
4649 float acc13 = 0.0f;
4650#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4651
4652#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4653 float acc20 = 0.0f;
4654 float acc21 = 0.0f;
4655 float acc22 = 0.0f;
4656 float acc23 = 0.0f;
4657#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4658
4659#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4660 float acc30 = 0.0f;
4661 float acc31 = 0.0f;
4662 float acc32 = 0.0f;
4663 float acc33 = 0.0f;
4664#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4665
4666 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004667 int i = 0;
4668 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004669 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004670#if defined(REINTERPRET_INPUT_AS_3D)
4671 // Load values from matrix A and matrix B
4672 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4673#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4674 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4675#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4676#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4677 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4678#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4679#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4680 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4682#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004683 // Load values from matrix A and matrix B
4684 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004686 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004687#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4688#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004689 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004690#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4691#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004692 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004693#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004694#endif // defined(REINTERPRET_INPUT_AS_3D)
4695
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004696 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4697 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004698
4699 // Multiply and accumulate
4700 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004701 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004702 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004703 acc03 = fma(a0.s0, b0.s3, acc03);
4704
4705#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004706
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004707 acc10 = fma(a1.s0, b0.s0, acc10);
4708 acc11 = fma(a1.s0, b0.s1, acc11);
4709 acc12 = fma(a1.s0, b0.s2, acc12);
4710 acc13 = fma(a1.s0, b0.s3, acc13);
4711
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004712#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4713#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004714
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004715 acc20 = fma(a2.s0, b0.s0, acc20);
4716 acc21 = fma(a2.s0, b0.s1, acc21);
4717 acc22 = fma(a2.s0, b0.s2, acc22);
4718 acc23 = fma(a2.s0, b0.s3, acc23);
4719
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004720#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4721#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004722
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004723 acc30 = fma(a3.s0, b0.s0, acc30);
4724 acc31 = fma(a3.s0, b0.s1, acc31);
4725 acc32 = fma(a3.s0, b0.s2, acc32);
4726 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004727#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004728
4729 // Load values from matrix A and matrix B
4730 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4731 src_addr.s1 += src1_stride_y;
4732
4733 // Multiply and accumulate
4734 acc00 = fma(a0.s1, b0.s0, acc00);
4735 acc01 = fma(a0.s1, b0.s1, acc01);
4736 acc02 = fma(a0.s1, b0.s2, acc02);
4737 acc03 = fma(a0.s1, b0.s3, acc03);
4738
4739#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4740
4741 acc10 = fma(a1.s1, b0.s0, acc10);
4742 acc11 = fma(a1.s1, b0.s1, acc11);
4743 acc12 = fma(a1.s1, b0.s2, acc12);
4744 acc13 = fma(a1.s1, b0.s3, acc13);
4745
4746#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4747#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4748
4749 acc20 = fma(a2.s1, b0.s0, acc20);
4750 acc21 = fma(a2.s1, b0.s1, acc21);
4751 acc22 = fma(a2.s1, b0.s2, acc22);
4752 acc23 = fma(a2.s1, b0.s3, acc23);
4753
4754#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4755#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4756
4757 acc30 = fma(a3.s1, b0.s0, acc30);
4758 acc31 = fma(a3.s1, b0.s1, acc31);
4759 acc32 = fma(a3.s1, b0.s2, acc32);
4760 acc33 = fma(a3.s1, b0.s3, acc33);
4761#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4762
4763 // Load values from matrix A and matrix B
4764 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4765 src_addr.s1 += src1_stride_y;
4766
4767 // Multiply and accumulate
4768 acc00 = fma(a0.s2, b0.s0, acc00);
4769 acc01 = fma(a0.s2, b0.s1, acc01);
4770 acc02 = fma(a0.s2, b0.s2, acc02);
4771 acc03 = fma(a0.s2, b0.s3, acc03);
4772
4773#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4774
4775 acc10 = fma(a1.s2, b0.s0, acc10);
4776 acc11 = fma(a1.s2, b0.s1, acc11);
4777 acc12 = fma(a1.s2, b0.s2, acc12);
4778 acc13 = fma(a1.s2, b0.s3, acc13);
4779
4780#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4781#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4782
4783 acc20 = fma(a2.s2, b0.s0, acc20);
4784 acc21 = fma(a2.s2, b0.s1, acc21);
4785 acc22 = fma(a2.s2, b0.s2, acc22);
4786 acc23 = fma(a2.s2, b0.s3, acc23);
4787
4788#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4789#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4790
4791 acc30 = fma(a3.s2, b0.s0, acc30);
4792 acc31 = fma(a3.s2, b0.s1, acc31);
4793 acc32 = fma(a3.s2, b0.s2, acc32);
4794 acc33 = fma(a3.s2, b0.s3, acc33);
4795#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4796
4797 // Load values from matrix A and matrix B
4798 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4799 src_addr.s1 += src1_stride_y;
4800
4801 // Multiply and accumulate
4802 acc00 = fma(a0.s3, b0.s0, acc00);
4803 acc01 = fma(a0.s3, b0.s1, acc01);
4804 acc02 = fma(a0.s3, b0.s2, acc02);
4805 acc03 = fma(a0.s3, b0.s3, acc03);
4806
4807#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4808
4809 acc10 = fma(a1.s3, b0.s0, acc10);
4810 acc11 = fma(a1.s3, b0.s1, acc11);
4811 acc12 = fma(a1.s3, b0.s2, acc12);
4812 acc13 = fma(a1.s3, b0.s3, acc13);
4813
4814#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4815#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4816
4817 acc20 = fma(a2.s3, b0.s0, acc20);
4818 acc21 = fma(a2.s3, b0.s1, acc21);
4819 acc22 = fma(a2.s3, b0.s2, acc22);
4820 acc23 = fma(a2.s3, b0.s3, acc23);
4821
4822#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4823#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4824
4825 acc30 = fma(a3.s3, b0.s0, acc30);
4826 acc31 = fma(a3.s3, b0.s1, acc31);
4827 acc32 = fma(a3.s3, b0.s2, acc32);
4828 acc33 = fma(a3.s3, b0.s3, acc33);
4829#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4830
4831 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004832 }
4833
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004834 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004835 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004836#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004837 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004838 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4839#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4840 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4841#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4842#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4843 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4844#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4845#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4846 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4847#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4848#else // defined(REINTERPRET_INPUT_AS_3D)
4849 // Load values from matrix A
4850 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004851#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4852 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4853#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4854#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4855 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4856#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4857#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4858 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4859#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004860#endif // defined(REINTERPRET_INPUT_AS_3D)
4861
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004862 // Load values from matrix B
4863 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004864 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004865
4866 // Multiply and accumulate
4867 acc00 = fma(a0, b0.s0, acc00);
4868 acc01 = fma(a0, b0.s1, acc01);
4869 acc02 = fma(a0, b0.s2, acc02);
4870 acc03 = fma(a0, b0.s3, acc03);
4871#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4872 acc10 = fma(a1, b0.s0, acc10);
4873 acc11 = fma(a1, b0.s1, acc11);
4874 acc12 = fma(a1, b0.s2, acc12);
4875 acc13 = fma(a1, b0.s3, acc13);
4876#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4877#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4878 acc20 = fma(a2, b0.s0, acc20);
4879 acc21 = fma(a2, b0.s1, acc21);
4880 acc22 = fma(a2, b0.s2, acc22);
4881 acc23 = fma(a2, b0.s3, acc23);
4882#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4883#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4884 acc30 = fma(a3, b0.s0, acc30);
4885 acc31 = fma(a3, b0.s1, acc31);
4886 acc32 = fma(a3, b0.s2, acc32);
4887 acc33 = fma(a3, b0.s3, acc33);
4888#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004889
4890 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004891 }
4892
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004893 int z = get_global_id(2);
4894
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004895 // Compute destination address
4896 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4897
4898 // Multiply by the weight of matrix-matrix product and store the result
4899#if defined(ALPHA)
4900 acc00 = acc00 * ALPHA;
4901 acc01 = acc01 * ALPHA;
4902 acc02 = acc02 * ALPHA;
4903 acc03 = acc03 * ALPHA;
4904#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004905#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004906 acc10 = acc10 * ALPHA;
4907 acc11 = acc11 * ALPHA;
4908 acc12 = acc12 * ALPHA;
4909 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004910#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4911#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004912 acc20 = acc20 * ALPHA;
4913 acc21 = acc21 * ALPHA;
4914 acc22 = acc22 * ALPHA;
4915 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004916#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4917#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004918 acc30 = acc30 * ALPHA;
4919 acc31 = acc31 * ALPHA;
4920 acc32 = acc32 * ALPHA;
4921 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004922#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4923
4924 // Compute dst address
4925 __global uchar *dst_addr = offset(&dst, 0, 0);
4926
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004927#if defined(ADD_VEC_C)
4928 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4929 float4 c0 = vload4(0, src2_addr);
4930
4931 acc00 += c0.s0;
4932 acc01 += c0.s1;
4933 acc02 += c0.s2;
4934 acc03 += c0.s3;
4935#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4936 acc10 += c0.s0;
4937 acc11 += c0.s1;
4938 acc12 += c0.s2;
4939 acc13 += c0.s3;
4940#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4941#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4942 acc20 += c0.s0;
4943 acc21 += c0.s1;
4944 acc22 += c0.s2;
4945 acc23 += c0.s3;
4946#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4947#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4948 acc30 += c0.s0;
4949 acc31 += c0.s1;
4950 acc32 += c0.s2;
4951 acc33 += c0.s3;
4952#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4953#endif /* defined(ADD_VEC_C) */
4954
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004955#if defined(REINTERPRET_OUTPUT_AS_3D)
4956 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004957 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004958 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004959 // | |
4960 // | plane0 |
4961 // | |
4962 // |__________________|
4963 // |******************|
4964 // | cross_plane_pad |
4965 // |******************|
4966 // | |
4967 // | plane1 |
4968 // | |
4969 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004970
4971 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4972 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4973 zout = min(DEPTH_GEMM3D - 1, zout);
4974
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004975 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004976 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004977
4978 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4979 // multiply dst_stride_z by DEPTH_GEMM3D
4980 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4981
4982 // Store the output block
4983 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4984#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4985 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4986#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4987#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4988 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4989#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4990#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4991 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004992#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004993
4994#else // defined(REINTERPRET_OUTPUT_AS_3D)
4995 // Add offset for batched GEMM
4996 dst_addr += z * dst_stride_z;
4997
4998 // Store the output block
4999 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
5000#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5001 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
5002#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5003#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5004 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
5005#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5006#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5007 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
5008#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5009#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005010}
5011
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 * The result written to dst is: dst = ALPHA * (A x B) [+ src2 broadcast along the rows, if ADD_VEC_C is defined].
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) only and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item computes exactly 2 output columns (two accumulators per row, matrix B loaded with vload2),
 *       while matrix B is offset by NUM_ELEMS_PROCESSED_PER_THREAD_X columns per work-item: this kernel must therefore be
 *       compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2, otherwise some output columns are never written.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4].
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings, in rows along Y, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings, in rows along Y, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                      VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Each work-item computes NUM_ELEMS_PROCESSED_PER_THREAD_Y rows x 2 columns of the output.
    // The main loop consumes 8 columns of A per row (vload8) and 8 rows of B, 2 values each (vload2).
    // idx is the first output/B column handled by this work-item (NUM_ELEMS_PROCESSED_PER_THREAD_X must be 2).
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (byte offset per row of the tile)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: accRC holds output row R (0..3) and column C (0..1) of this work-item's tile
    float acc00 = 0.0f;
    float acc01 = 0.0f;

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float acc10 = 0.0f;
    float acc11 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float acc20 = 0.0f;
    float acc21 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float acc30 = 0.0f;
    float acc31 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: process 8 columns of A (and 8 rows of B) per iteration.
    int i = 0;
    for(; i <= ((int)COLS_A - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0.s0, b0.s0, acc00);
        acc00 = fma(a0.s1, b1.s0, acc00);
        acc00 = fma(a0.s2, b2.s0, acc00);
        acc00 = fma(a0.s3, b3.s0, acc00);
        acc00 = fma(a0.s4, b4.s0, acc00);
        acc00 = fma(a0.s5, b5.s0, acc00);
        acc00 = fma(a0.s6, b6.s0, acc00);
        acc00 = fma(a0.s7, b7.s0, acc00);

        acc01 = fma(a0.s0, b0.s1, acc01);
        acc01 = fma(a0.s1, b1.s1, acc01);
        acc01 = fma(a0.s2, b2.s1, acc01);
        acc01 = fma(a0.s3, b3.s1, acc01);
        acc01 = fma(a0.s4, b4.s1, acc01);
        acc01 = fma(a0.s5, b5.s1, acc01);
        acc01 = fma(a0.s6, b6.s1, acc01);
        acc01 = fma(a0.s7, b7.s1, acc01);

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        // Reuse a0 for the next row of A; the B values loaded above are shared by all rows
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc10 = fma(a0.s0, b0.s0, acc10);
        acc10 = fma(a0.s1, b1.s0, acc10);
        acc10 = fma(a0.s2, b2.s0, acc10);
        acc10 = fma(a0.s3, b3.s0, acc10);
        acc10 = fma(a0.s4, b4.s0, acc10);
        acc10 = fma(a0.s5, b5.s0, acc10);
        acc10 = fma(a0.s6, b6.s0, acc10);
        acc10 = fma(a0.s7, b7.s0, acc10);

        acc11 = fma(a0.s0, b0.s1, acc11);
        acc11 = fma(a0.s1, b1.s1, acc11);
        acc11 = fma(a0.s2, b2.s1, acc11);
        acc11 = fma(a0.s3, b3.s1, acc11);
        acc11 = fma(a0.s4, b4.s1, acc11);
        acc11 = fma(a0.s5, b5.s1, acc11);
        acc11 = fma(a0.s6, b6.s1, acc11);
        acc11 = fma(a0.s7, b7.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc20 = fma(a0.s0, b0.s0, acc20);
        acc20 = fma(a0.s1, b1.s0, acc20);
        acc20 = fma(a0.s2, b2.s0, acc20);
        acc20 = fma(a0.s3, b3.s0, acc20);
        acc20 = fma(a0.s4, b4.s0, acc20);
        acc20 = fma(a0.s5, b5.s0, acc20);
        acc20 = fma(a0.s6, b6.s0, acc20);
        acc20 = fma(a0.s7, b7.s0, acc20);

        acc21 = fma(a0.s0, b0.s1, acc21);
        acc21 = fma(a0.s1, b1.s1, acc21);
        acc21 = fma(a0.s2, b2.s1, acc21);
        acc21 = fma(a0.s3, b3.s1, acc21);
        acc21 = fma(a0.s4, b4.s1, acc21);
        acc21 = fma(a0.s5, b5.s1, acc21);
        acc21 = fma(a0.s6, b6.s1, acc21);
        acc21 = fma(a0.s7, b7.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc30 = fma(a0.s0, b0.s0, acc30);
        acc30 = fma(a0.s1, b1.s0, acc30);
        acc30 = fma(a0.s2, b2.s0, acc30);
        acc30 = fma(a0.s3, b3.s0, acc30);
        acc30 = fma(a0.s4, b4.s0, acc30);
        acc30 = fma(a0.s5, b5.s0, acc30);
        acc30 = fma(a0.s6, b6.s0, acc30);
        acc30 = fma(a0.s7, b7.s0, acc30);

        acc31 = fma(a0.s0, b0.s1, acc31);
        acc31 = fma(a0.s1, b1.s1, acc31);
        acc31 = fma(a0.s2, b2.s1, acc31);
        acc31 = fma(a0.s3, b3.s1, acc31);
        acc31 = fma(a0.s4, b4.s1, acc31);
        acc31 = fma(a0.s5, b5.s1, acc31);
        acc31 = fma(a0.s6, b6.s1, acc31);
        acc31 = fma(a0.s7, b7.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float) * 8;
    }
    // Left-over loop: process the remaining (COLS_A % 8) columns of A one float at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc00 = fma(a0, b0.s0, acc00);
        acc01 = fma(a0, b0.s1, acc01);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc10 = fma(a1, b0.s0, acc10);
        acc11 = fma(a1, b0.s1, acc11);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc20 = fma(a2, b0.s0, acc20);
        acc21 = fma(a2, b0.s1, acc21);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc30 = fma(a3, b0.s0, acc30);
        acc31 = fma(a3, b0.s1, acc31);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ADD_VEC_C)
    // The same 2 values of the vector src2 are added to every row of the output tile (applied after ALPHA)
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float2          c0        = vload2(0, src2_addr);

    acc00 += c0.s0;
    acc01 += c0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc10 += c0.s0;
    acc11 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc20 += c0.s0;
    acc21 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc30 += c0.s0;
    acc31 += c0.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (byte offset per row of the tile)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
5429
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005430#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005431/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5432 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005433 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5434 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005435 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
5436 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5437 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5438 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5439 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5440 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5441 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5442 *
5443 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5444 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
5445 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5446 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5447 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5448 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5449 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005450 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5451 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005452 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5453 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5454 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5455 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5456 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5457 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5458 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5459 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5460 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5461 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5462 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5463 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005464 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
5465 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5466 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5467 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005468 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5469 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5470 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5471 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5472 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5473 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5474 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5475 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5476 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5477 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5478 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
5479 */
5480__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
5481 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005482#if defined(ADD_VEC_C)
5483 VECTOR_DECLARATION(src2),
5484#endif /* defined(ADD_VEC_C) */
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005485 IMAGE_DECLARATION(dst),
5486 uint src0_stride_z,
5487 uint src1_stride_z,
5488 uint dst_stride_z
5489#if defined(REINTERPRET_INPUT_AS_3D)
5490 ,
5491 uint src_cross_plane_pad
5492#endif // REINTERPRET_INPUT_AS_3D
5493#if defined(REINTERPRET_OUTPUT_AS_3D)
5494 ,
5495 uint dst_cross_plane_pad
5496#endif // REINTERPRET_OUTPUT_AS_3D
5497 )
5498{
5499 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5500
5501 // Compute starting address for matrix A and Matrix B
5502 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5503
5504 // Update address for the matrix A
5505 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5506
5507 // Update address for the matrix B
5508 src_addr.s1 += idx * sizeof(half);
5509
5510#if defined(REINTERPRET_INPUT_AS_3D)
5511 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5512 // in order to take into account the presence of possible cross plane paddings
5513 //
5514 // | |
5515 // | plane0 |
5516 // | |
5517 // |__________________|
5518 // |******************|
5519 // | cross_plane_pad |
5520 // |******************|
5521 // | |
5522 // | plane1 |
5523 // | |
5524 // |__________________|
5525
5526 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5527 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5528 zin = min(DEPTH_GEMM3D - 1, zin);
5529
5530 // Add offset due to the cross plane paddings
5531 zin *= (src_cross_plane_pad * src0_stride_y);
5532
5533 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5534 // multiply src0_stride_z by DEPTH_GEMM3D
5535 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5536
5537#else // defined(REINTERPRET_INPUT_AS_3D)
5538
5539 // Add offset for batched GEMM
5540 src_addr.s0 += get_global_id(2) * src0_stride_z;
5541
5542#endif // defined(REINTERPRET_INPUT_AS_3D)
5543
5544#if defined(MATRIX_B_DEPTH)
5545 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5546 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5547#else // defined(MATRIX_B_DEPTH)
5548 src_addr.s1 += get_global_id(2) * src1_stride_z;
5549#endif // defined(MATRIX_B_DEPTH)
5550
5551 float8 acc0 = 0.0h;
5552#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5553 float8 acc1 = 0.0h;
5554#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5555#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5556 float8 acc2 = 0.0h;
5557#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5558#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5559 float8 acc3 = 0.0h;
5560#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5561
5562 int i = 0;
5563 for(; i <= ((int)COLS_A - 4); i += 4)
5564 {
5565#if defined(REINTERPRET_INPUT_AS_3D)
5566 // Load values from matrix A
5567 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5568#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5569 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5570#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5571#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5572 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5573#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5574#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5575 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5576#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5577#else // defined(REINTERPRET_INPUT_AS_3D)
5578 // Load values from matrix A
5579 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5580#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5581 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5582#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5583#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5584 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5585#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5586#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5587 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5588#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5589#endif // defined(REINTERPRET_INPUT_AS_3D)
5590
5591 // Load values from matrix B
5592 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5593 src_addr.s1 += src1_stride_y;
5594
5595 // Accumulate
5596 acc0 = fma(b0, (float8)a0.s0, acc0);
5597#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5598 acc1 = fma(b0, (float8)a1.s0, acc1);
5599#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5600#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5601 acc2 = fma(b0, (float8)a2.s0, acc2);
5602#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5603#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5604 acc3 = fma(b0, (float8)a3.s0, acc3);
5605#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5606
5607 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5608 src_addr.s1 += src1_stride_y;
5609 acc0 = fma(b0, (float8)a0.s1, acc0);
5610#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5611 acc1 = fma(b0, (float8)a1.s1, acc1);
5612#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5613#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5614 acc2 = fma(b0, (float8)a2.s1, acc2);
5615#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5616#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5617 acc3 = fma(b0, (float8)a3.s1, acc3);
5618#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5619
5620 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5621 src_addr.s1 += src1_stride_y;
5622 acc0 = fma(b0, (float8)a0.s2, acc0);
5623#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5624 acc1 = fma(b0, (float8)a1.s2, acc1);
5625#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5626#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5627 acc2 = fma(b0, (float8)a2.s2, acc2);
5628#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5629#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5630 acc3 = fma(b0, (float8)a3.s2, acc3);
5631#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5632
5633 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5634 src_addr.s1 += src1_stride_y;
5635 acc0 = fma(b0, (float8)a0.s3, acc0);
5636#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5637 acc1 = fma(b0, (float8)a1.s3, acc1);
5638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5639#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5640 acc2 = fma(b0, (float8)a2.s3, acc2);
5641#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5642#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5643 acc3 = fma(b0, (float8)a3.s3, acc3);
5644#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5645
5646 src_addr.s0 += 4 * sizeof(half);
5647 }
5648
5649 for(; i < (int)COLS_A; ++i)
5650 {
5651#if defined(REINTERPRET_INPUT_AS_3D)
5652 // Load values from matrix A
5653 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5654#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5655 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5656#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5657#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5658 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5659#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5660#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5661 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5662#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5663#else // defined(REINTERPRET_INPUT_AS_3D)
5664 // Load values from matrix A
5665 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5666#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5667 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5668#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5669#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5670 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5671#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5672#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5673 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5674#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5675#endif // defined(REINTERPRET_INPUT_AS_3D)
5676
5677 // Load values from matrix B
5678 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5679
5680 src_addr += (int2)(sizeof(half), src1_stride_y);
5681
5682 // Accumulate
5683 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
5684#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5685 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
5686#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5687#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5688 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
5689#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5690#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5691 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
5692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5693 }
5694
5695 // Multiply by the weight of matrix-matrix product and store the result
5696#if defined(ALPHA)
5697 half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
5698#else //defined(ALPHA)
5699 half8 hacc0 = convert_half8(acc0);
5700#endif // defined(ALPHA)
5701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5702#if defined(ALPHA)
5703 half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
5704#else //defined(ALPHA)
5705 half8 hacc1 = convert_half8(acc1);
5706#endif //defined(ALPHA)
5707#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y
5708
5709#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5710#if defined(ALPHA)
5711 half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
5712#else //defined(ALPHA)
5713 half8 hacc2 = convert_half8(acc2);
5714#endif //defined(ALPHA)
5715#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5716
5717#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5718#if defined(ALPHA)
5719 half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
5720#else //defined(ALPHA)
5721 half8 hacc3 = convert_half8(acc3);
5722#endif // defined(ALPHA)
5723#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5724
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005725#if defined(ADD_VEC_C)
5726 // *INDENT-OFF*
5727 // clang-format off
5728 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
5729 half8 c0 = vload8(0, src2_addr);
5730 // clang-format on
5731 // *INDENT-ON*
5732
5733 hacc0 += c0;
5734#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5735 hacc1 += c0;
5736#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5737#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5738 hacc2 += c0;
5739#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5740#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5741 hacc3 += c0;
5742#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5743#endif /* defined(ADD_VEC_C) */
5744
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005745 int z = get_global_id(2);
5746
5747 // Compute destination address
5748 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5749
5750 // Compute dst address
5751 __global uchar *dst_addr = offset(&dst, 0, 0);
5752
5753#if defined(REINTERPRET_OUTPUT_AS_3D)
5754 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
5755 // in order to take into account the presence of possible cross plane paddings
5756 //
5757 // | |
5758 // | plane0 |
5759 // | |
5760 // |__________________|
5761 // |******************|
5762 // | cross_plane_pad |
5763 // |******************|
5764 // | |
5765 // | plane1 |
5766 // | |
5767 // |__________________|
5768
5769 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5770 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5771 zout = min(DEPTH_GEMM3D - 1, zout);
5772
5773 // Add offset due to the cross plane paddings
5774 zout *= (dst_cross_plane_pad * dst_stride_y);
5775
5776 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5777 // multiply dst_stride_z by DEPTH_GEMM3D
5778 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
5779
5780 // Store the output block
5781 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
5782#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5783 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
5784#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5785#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5786 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
5787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5788#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5789 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
5790#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5791
5792#else // defined(REINTERPRET_OUTPUT_AS_3D)
5793 // Add offset for batched GEMM
5794 dst_addr += z * dst_stride_z;
5795
5796 // Store the output block
5797 vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
5798#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5799 vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
5800#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5801#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5802 vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
5803#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5804#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5805 vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
5806#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5807#endif // REINTERPRET_OUTPUT_AS_3D
5808}
5809
5810/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
5811 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005812 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5813 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005814 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5815 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5816 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5817 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5818 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5819 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5820 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5821 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005822 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5823 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005824 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5825 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5826 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5827 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5828 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005829 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5830 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005831 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5832 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5833 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5834 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5835 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5836 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5837 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5838 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5839 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5840 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5841 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5842 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005843 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
5844 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5845 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5846 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005847 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5848 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5849 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5850 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5851 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5852 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005853 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5854 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5855 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005856 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5857 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005858 */
5859__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5860 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005861#if defined(ADD_VEC_C)
5862 VECTOR_DECLARATION(src2),
5863#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005864 IMAGE_DECLARATION(dst),
5865 uint src0_stride_z,
5866 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005867 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005868#if defined(REINTERPRET_INPUT_AS_3D)
5869 ,
5870 uint src_cross_plane_pad
5871#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005872#if defined(REINTERPRET_OUTPUT_AS_3D)
5873 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005874 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005875#endif // REINTERPRET_OUTPUT_AS_3D
5876 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005877{
5878 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5879
5880 // Compute starting address for matrix A and Matrix B
5881 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5882
5883 // Update address for the matrix A
5884 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5885
5886 // Update address for the matrix B
5887 src_addr.s1 += idx * sizeof(half);
5888
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005889#if defined(REINTERPRET_INPUT_AS_3D)
5890 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5891 // in order to take into account the presence of possible cross plane paddings
5892 //
5893 // | |
5894 // | plane0 |
5895 // | |
5896 // |__________________|
5897 // |******************|
5898 // | cross_plane_pad |
5899 // |******************|
5900 // | |
5901 // | plane1 |
5902 // | |
5903 // |__________________|
5904
5905 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5906 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5907 zin = min(DEPTH_GEMM3D - 1, zin);
5908
5909 // Add offset due to the cross plane paddings
5910 zin *= (src_cross_plane_pad * src0_stride_y);
5911
5912 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5913 // multiply src0_stride_z by DEPTH_GEMM3D
5914 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5915
5916#else // defined(REINTERPRET_INPUT_AS_3D)
5917
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005918 // Add offset for batched GEMM
5919 src_addr.s0 += get_global_id(2) * src0_stride_z;
5920
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005921#endif // defined(REINTERPRET_INPUT_AS_3D)
5922
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005923#if defined(MATRIX_B_DEPTH)
5924 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5925 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5926#else // defined(MATRIX_B_DEPTH)
5927 src_addr.s1 += get_global_id(2) * src1_stride_z;
5928#endif // defined(MATRIX_B_DEPTH)
5929
5930 half8 acc0 = 0.0h;
5931#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5932 half8 acc1 = 0.0h;
5933#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5934#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5935 half8 acc2 = 0.0h;
5936#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5937#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5938 half8 acc3 = 0.0h;
5939#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5940
5941 int i = 0;
5942 for(; i <= ((int)COLS_A - 4); i += 4)
5943 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005944#if defined(REINTERPRET_INPUT_AS_3D)
5945 // Load values from matrix A
5946 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5947#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5948 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5949#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5950#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5951 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5952#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5953#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5954 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5955#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5956#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005957 // Load values from matrix A
5958 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5959#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5960 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5961#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5962#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5963 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5964#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5965#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5966 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5967#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005968#endif // defined(REINTERPRET_INPUT_AS_3D)
5969
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005970 // Load values from matrix B
5971 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5972 src_addr.s1 += src1_stride_y;
5973
5974 // Accumulate
5975 acc0 = fma(b0, (half8)a0.s0, acc0);
5976#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5977 acc1 = fma(b0, (half8)a1.s0, acc1);
5978#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5979#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5980 acc2 = fma(b0, (half8)a2.s0, acc2);
5981#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5982#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5983 acc3 = fma(b0, (half8)a3.s0, acc3);
5984#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5985
5986 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5987 src_addr.s1 += src1_stride_y;
5988 acc0 = fma(b0, (half8)a0.s1, acc0);
5989#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5990 acc1 = fma(b0, (half8)a1.s1, acc1);
5991#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5992#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5993 acc2 = fma(b0, (half8)a2.s1, acc2);
5994#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5995#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5996 acc3 = fma(b0, (half8)a3.s1, acc3);
5997#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5998
5999 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6000 src_addr.s1 += src1_stride_y;
6001 acc0 = fma(b0, (half8)a0.s2, acc0);
6002#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6003 acc1 = fma(b0, (half8)a1.s2, acc1);
6004#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6005#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6006 acc2 = fma(b0, (half8)a2.s2, acc2);
6007#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6008#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6009 acc3 = fma(b0, (half8)a3.s2, acc3);
6010#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6011
6012 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6013 src_addr.s1 += src1_stride_y;
6014 acc0 = fma(b0, (half8)a0.s3, acc0);
6015#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6016 acc1 = fma(b0, (half8)a1.s3, acc1);
6017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6018#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6019 acc2 = fma(b0, (half8)a2.s3, acc2);
6020#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6021#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6022 acc3 = fma(b0, (half8)a3.s3, acc3);
6023#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6024
6025 src_addr.s0 += 4 * sizeof(half);
6026 }
6027
6028 for(; i < (int)COLS_A; ++i)
6029 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006030#if defined(REINTERPRET_INPUT_AS_3D)
6031 // Load values from matrix A
6032 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
6033#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6034 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
6035#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6036#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6037 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
6038#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6039#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6040 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
6041#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6042#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006043 // Load values from matrix A
6044 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6045#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6046 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6047#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6048#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6049 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6050#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6051#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6052 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6053#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006054#endif // defined(REINTERPRET_INPUT_AS_3D)
6055
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006056 // Load values from matrix B
6057 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6058
6059 src_addr += (int2)(sizeof(half), src1_stride_y);
6060
6061 // Accumulate
6062 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
6063#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6064 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
6065#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6066#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6067 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
6068#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6069#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6070 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
6071#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6072 }
6073
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006074 // Multiply by the weight of matrix-matrix product and store the result
6075#if defined(ALPHA)
6076 acc0 = acc0 * (half8)ALPHA;
6077#endif // defined(ALPHA)
6078#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
6079 acc1 = acc1 * (half8)ALPHA;
6080#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
6081#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
6082 acc2 = acc2 * (half8)ALPHA;
6083#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
6084#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
6085 acc3 = acc3 * (half8)ALPHA;
6086#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
6087
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00006088#if defined(ADD_VEC_C)
6089 // *INDENT-OFF*
6090 // clang-format off
6091 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
6092 half8 c0 = vload8(0, src2_addr);
6093 // clang-format on
6094 // *INDENT-ON*
6095
6096 acc0 += c0;
6097#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6098 acc1 += c0;
6099#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6100#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6101 acc2 += c0;
6102#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6103#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6104 acc3 += c0;
6105#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6106#endif /* defined(ADD_VEC_C) */
6107
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006108 int z = get_global_id(2);
6109
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006110 // Compute destination address
6111 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6112
6113 // Compute dst address
6114 __global uchar *dst_addr = offset(&dst, 0, 0);
6115
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006116#if defined(REINTERPRET_OUTPUT_AS_3D)
6117 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006118 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006119 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006120 // | |
6121 // | plane0 |
6122 // | |
6123 // |__________________|
6124 // |******************|
6125 // | cross_plane_pad |
6126 // |******************|
6127 // | |
6128 // | plane1 |
6129 // | |
6130 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006131
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006132 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
6133 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6134 zout = min(DEPTH_GEMM3D - 1, zout);
6135
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006136 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006137 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006138
6139 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6140 // multiply dst_stride_z by DEPTH_GEMM3D
6141 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
6142
6143 // Store the output block
6144 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
6145#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6146 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
6147#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6148#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6149 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
6150#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6151#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6152 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
6153#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6154
6155#else // defined(REINTERPRET_OUTPUT_AS_3D)
6156 // Add offset for batched GEMM
6157 dst_addr += z * dst_stride_z;
6158
6159 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006160 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
6161#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006162 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
6163#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6164#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006165 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
6166#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6167#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006168 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
6169#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006170#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006171}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006172#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006173
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006174#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006175
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006176#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *   dst = dst + beta * src
 *
 * dst is read as the (alpha-weighted) A x B product and overwritten with the result; src holds matrix C.
 * Each work-item loads and stores 4 consecutive F32 elements starting at its own src/dst address.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (read as A x B, then overwritten). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (dst is both an input and the output of this kernel)
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
6217
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 *   dst = dst + beta * src
 *
 * dst is read as the (alpha-weighted) A x B product and overwritten with the result; src holds matrix C.
 * Each work-item loads and stores 8 consecutive F16 elements starting at its own src/dst address.
 *
 * @note The beta's value needs to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix (read as A x B, then overwritten). Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (dst is both an input and the output of this kernel)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006260#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006261
#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Row idy of A (idy = get_global_id(1)) is multiplied by the idy-th Z slice of B, so every row of A has its own
 * weight matrix. Each work-item produces 4 consecutive output values, starting at column 4 * get_global_id(0) of B.
 *
 * @note The width of A needs to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix (A). Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix (B). Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    // s0: byte offset of row idy of A
    // s1: byte offset of column idx inside the idy-th Z slice of B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    // Byte offset one past the last element of the current row of A.
    // The width term is cast to int so the sum stays in int instead of being computed in size_t and narrowed.
    int end_row_vec_a = src_addr.s0 + (int)(WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume 2 elements of A (and the matching 2 rows of B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Left-over loop: at most one element of A remains (when WIDTH_VECTOR_A is odd)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
6329
/** This kernel accumulates each row with the biases vector.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and VECTOR_SIZE elements from the
 * biases vector, adds them element-wise with the plain '+' operator (no saturation for integer types) and
 * writes the sum back in place to the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE elements from both the accumulate tensor and the biases vector
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
    accum_value = biases_value + accum_value;
    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)