blob: 211484440bb2189d6ed53ddafa40cdf1976defe8 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Per-lane offsets 0..(N-1), used to compute the column index of every lane of a K0-wide vector.
// One constant is provided for each supported K0 value (2, 3, 4, 8, 16).
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
// Two-level expansion: INC(K0) first expands K0 to its numeric value (e.g. 4), then pastes it (INC4).
#define CONCAT_INC(K0) INC##K0
#define INC(K0) CONCAT_INC(K0)

#if(SRC_WIDTH % K0)
/** Mask the lanes of the K0-wide vector @p a whose column index is outside the source width.
 *
 * Lane j of the block handled by work-item @p x covers column (x * K0 + j). Lanes whose column index
 * is not < SRC_WIDTH are replaced by 0 through select(). This variant is only emitted when SRC_WIDTH
 * is not a multiple of K0, i.e. when the last block along X is partial.
 *
 * @param[in]     x Block index along X (work-item id in dimension 0)
 * @param[in,out] a K0-wide vector of DATA_TYPE, updated in place
 */
#define BOUNDARY_CONDITION_X(x, a)                                                                                                                          \
    ({                                                                                                                                                      \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
// SRC_WIDTH is a multiple of K0: every block is full along X, so no masking is required.
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * Work-item (x, y, z) reads the M0xK0 block whose top-left element is at column (x * K0) and row (y * M0) of batch z.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note When SRC_WIDTH is not a multiple of K0, the lanes of the last block along X that fall outside SRC_WIDTH are masked to zero before being stored.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom padding between two consecutive planes, in number of rows (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the blocks of consecutive y that share one output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block in the output.
    // Fully parenthesized so the expansion is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: the block starts at column x * K0 and row y * M0
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y are stored on the same output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D. The plane index is clamped to
    // DEPTH_GEMM3D - 1 and then converted into a byte offset that skips the padding rows of the preceding planes.
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (K0 elements each) from the LHS matrix,
    // masking the columns beyond SRC_WIDTH when the block is partial along X
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Row i of the block is written OUTPUT_STEP_X * i elements after output_ptr
    VSTORE(K0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if M0 > 1
    VSTORE(K0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 1
#if M0 > 2
    VSTORE(K0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 2
#if M0 > 3
    VSTORE(K0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 3
#if M0 > 4
    VSTORE(K0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 4
#if M0 > 5
    VSTORE(K0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 5
#if M0 > 6
    VSTORE(K0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 6
#if M0 > 7
    VSTORE(K0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 7

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000274
// TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)
// Gathers column i (a single hexadecimal digit 0-F) of the row vectors a0..a(M0-1) into one M0-element
// column and stores it (output_step_x * 0x<i>) elements after output_ptr. output_ptr is expected to be a
// byte (uchar) pointer. The row vectors a0..a(M0-1) must be in scope at the expansion site.
// M0 values that are not valid vector widths (5, 6, 7) are stored as a 4-wide vector followed by the remaining elements.
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                 \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                        \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                 \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                               \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                      \
    ({                                                                                                 \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                   \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                      \
        VSTORE(M0)                                                                                     \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                        \
    ({                                                                                                   \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                      \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);              \
        DATA_TYPE res1 = a4.s##i;                                                                        \
        VSTORE(4)                                                                                        \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));  \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                     \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                            \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        VSTORE(2)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                          \
    ({                                                                                                     \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                        \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                        \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                   \
        VSTORE(4)                                                                                          \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));    \
        VSTORE(3)                                                                                          \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                   \
    ({                                                                                                              \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                  \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));              \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
344
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * Work-item (x, y, z) reads the M0xK0 block whose top-left element is at column (x * K0) and row (y * M0) of batch z,
 * and writes each of its K0 columns (M0 elements) as one contiguous output row of the block.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note When SRC_WIDTH is not a multiple of K0, the lanes of the last block along X that fall outside SRC_WIDTH are masked to zero before being stored.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom padding between two consecutive planes, in number of rows (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the blocks of consecutive y that share one output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive transposed rows of the same block.
    // Fully parenthesized so the expansion is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: the block starts at column x * K0 and row y * M0
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y are stored on the same output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D. The plane index is clamped to
    // DEPTH_GEMM3D - 1 and then converted into a byte offset that skips the padding rows of the preceding planes.
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load the M0 rows of the block (K0 elements each) from the LHS matrix,
    // masking the columns beyond SRC_WIDTH when the block is partial along X
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // Column c of the block (c = 0 .. K0-1, as a hex digit) is written OUTPUT_STEP_X * c elements after output_ptr
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000567#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000568
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000569#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
570/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
571 * the output matrix unrolling the values.
572 *
573 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
574 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
575 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
576 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
577 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
578 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000579 * N0: 2,3,4,8,16
580 * K0: 1,2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000581 * H0: greater than 0
582 *
583 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
584 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
585 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
586 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
587 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
588 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
589 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
590 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
591 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
592 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
593 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
594 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
595 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
596 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
597 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
598 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
599 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // Offset (in elements) between the H0 blocks that share one output row.
    // With INTERLEAVE the blocks are interleaved row by row (offset N0), otherwise
    // they are stored back to back (offset of one whole block).
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Step (in elements) between two consecutive rows of the same block in the output.
    // The whole expansion is parenthesized so the macro is safe in any expression context.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the N0-wide column block, y the K0-high row block,
    // z the plane along the Z dimension
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: top-left element of the K0xN0 block handled by this work-item
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Output address: H0 blocks share one output row, so (x % H0) selects the slot inside
    // the row and (x / H0) selects the output row. Each y advances by H0 whole blocks along X.
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x / (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // VEC_DATA_TYPE(DATA_TYPE, N0) a0 = 0, a1 = 0, ..., a(K0-1) = 0;
    // Rows that fall beyond SRC_HEIGHT are not loaded and therefore keep this zero value,
    // which zero-pads the last block.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix.
    // Row 0 is loaded without a bounds check: presumably the host never launches a y
    // whose first row lies past SRC_HEIGHT -- TODO confirm against the host-side kernel.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // All K0 rows are stored, including zero-padded ones, so the reshaped block is always complete.
    VSTORE(N0)(a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if K0 > 1
    VSTORE(N0)(a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 1
#if K0 > 2
    VSTORE(N0)(a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 2
#if K0 > 3
    VSTORE(N0)(a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 3
#if K0 > 4
    VSTORE(N0)(a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 4
#if K0 > 8
    VSTORE(N0)(a8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(a9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(N0)(aF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // K0 > 8

    // These helper macros are local to this kernel
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
761
#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                N0: 2,3,4,8,16
 *                                K0: 2,3,4,8,16
 *                                H0: greater than 0
 * @note Rows of a block that lie at or beyond SRC_HEIGHT are written to the output as zeros.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // Offset (in elements) between the H0 transposed blocks that share one output row.
    // A transposed block has N0 rows of K0 elements: with INTERLEAVE the blocks are
    // interleaved row by row (offset K0), otherwise stored back to back (one whole block).
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Step (in elements) between two consecutive rows of the same transposed block in the output.
    // The whole expansion is parenthesized so the macro is safe in any expression context.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the N0-wide column block, y the K0-high row block,
    // z the plane along the Z dimension
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: top-left element of the K0xN0 block handled by this work-item
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Output address: H0 blocks share one output row, so (x % H0) selects the slot inside
    // the row and (x / H0) selects the output row. Each y advances by H0 whole blocks along X.
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x / (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // VEC_DATA_TYPE(DATA_TYPE, N0) a0 = 0, a1 = 0, ..., a(K0-1) = 0;
    // Rows at or beyond SRC_HEIGHT are not loaded and keep this zero value.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix (K0 >= 2 for this kernel, so a1 always exists).
    // Row 0 is loaded without a bounds check: presumably the host never launches a y
    // whose first row lies past SRC_HEIGHT -- TODO confirm against the host-side kernel.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    // VEC_DATA_TYPE(DATA_TYPE, K0) res0 = 0, res1 = 0, ..., res(N0-1) = 0;
    // res<n> gathers column n of the loaded block, i.e. (a0.s<n>, a1.s<n>, ..., a(K0-1).s<n>).
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0);

#if K0 == 2
    // This part computes the following transpositions:
    // 2x2 -> 2x2
    // 2x4 -> 4x2
    // 2x8 -> 8x2
    // 2x16 -> 16x2
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
#endif // N0 > 8

#elif K0 == 3 // K0 == 3
    // This part computes the following transpositions:
    // 3x2 -> 2x3
    // 3x4 -> 4x3
    // 3x8 -> 8x3
    // 3x16 -> 16x3
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8

#elif K0 == 4 // K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not supported
    // The transposition is selected on K0, so an unsupported K0 must be reported as such
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // Each transposed row res<n> (K0 elements) is written OUTPUT_STEP_X elements after the previous one.
    VSTORE(K0)(res0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if N0 > 2
    VSTORE(K0)(res2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 2
#if N0 > 3
    VSTORE(K0)(res3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 3
#if N0 > 4
    VSTORE(K0)(res4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 4
#if N0 > 8
    VSTORE(K0)(res8, 0, (__global DATA_TYPE *)(output_ptr + 8 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(res9, 0, (__global DATA_TYPE *)(output_ptr + 9 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resA, 0, (__global DATA_TYPE *)(output_ptr + 10 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resB, 0, (__global DATA_TYPE *)(output_ptr + 11 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resC, 0, (__global DATA_TYPE *)(output_ptr + 12 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resD, 0, (__global DATA_TYPE *)(output_ptr + 13 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resE, 0, (__global DATA_TYPE *)(output_ptr + 14 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
    VSTORE(K0)(resF, 0, (__global DATA_TYPE *)(output_ptr + 15 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // N0 > 8

    // These helper macros are local to this kernel
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
1129#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
1130
#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(K)

// Token-pasting helper. Its arguments are pasted without being macro-expanded first, so an
// argument that is itself a macro (e.g. K0) only expands when it is forwarded through another
// macro's parameter, as ARM_DOT_K0XN0 does with k0.
#define CONCAT(a, b) a##b

// ARM_DOT<n>(a, b, c): accumulates the dot product of the n-element vectors a and b into the
// scalar lvalue c through a chain of fma() calls, i.e. c = a.s0 * b.s0 + ... + a.s(n-1) * b.s(n-1) + c.
// Bodies use the GNU statement-expression form `({ ... })`.
// NOTE(review): a and b are not wrapped in parentheses before member access (e.g. `a.s0`), so
// callers must pass parenthesized operands, as ARM_DOT_K0XN0 does.
#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)       \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
    })
#define ARM_DOT3(a, b, c)           \
    ({                              \
        ARM_DOT2(a, b, c);          \
        c = fma((a.s2), (b.s2), c); \
    })
#define ARM_DOT4(a, b, c)           \
    ({                              \
        ARM_DOT3(a, b, c);          \
        c = fma((a.s3), (b.s3), c); \
    })
// The 8- and 16-element variants split their inputs into the .lo/.hi halves and recurse.
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a.lo), (b.lo), c); \
        ARM_DOT4((a.hi), (b.hi), c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a.lo), (b.lo), c); \
        ARM_DOT8((a.hi), (b.hi), c); \
    })

// ARM_DOT_K0XN0(k0, a, b, c): for every n in [0, N0), accumulates ARM_DOT<k0>(a, b<n>, c.s<n>),
// i.e. c.s<n> += dot(a, b<n>).
//   k0 - vector length of a and of each b<n>; selects ARM_DOT1/2/3/4/8/16 via CONCAT
//   a  - k0-element vector
//   b  - name prefix of the N0 variables b0, b1, ... (pasted with ##, so it must be a plain identifier)
//   c  - N0-element accumulator vector, updated in place
// One variant is defined per supported N0; any other N0 stops compilation.
#if N0 == 2
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##4), (c.s4));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##5), (c.s5));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##6), (c.s6));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##7), (c.s7));     \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(k0, a, b, c) \
    ({                             \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##0), (c.s0));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##1), (c.s1));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##2), (c.s2));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##3), (c.s3));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##4), (c.s4));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##5), (c.s5));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##6), (c.s6));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##7), (c.s7));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##8), (c.s8));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##9), (c.s9));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##A), (c.sA));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##B), (c.sB));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##C), (c.sC));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##D), (c.sD));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##E), (c.sE));     \
        CONCAT(ARM_DOT, k0)        \
        ((a), (b##F), (c.sF));     \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1254
1255/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1256 * The LHS matrix is NOT reshaped
1257 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1258 *
1259 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
1260 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1261 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1262 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1263 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1264 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1265 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1266 * - N0 = 2, 3, 4, 8, 16
1267 * - K0 = 2, 3, 4, 8, 16
1268 * - H0 > 1
1269 *
1270 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1271 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1272 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1273 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1274 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1275 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1276 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix (not reshaped). Supported data type: F16/F32
 * @param[in]  lhs_stride_x                      Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                      Stride of the LHS matrix in Z dimension (in bytes)
1296 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1297 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1298 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1299 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1300 */
1301__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1302 IMAGE_DECLARATION(rhs),
1303 IMAGE_DECLARATION(dst),
1304 uint lhs_stride_z,
1305 uint rhs_stride_z,
1306 uint dst_stride_z
1307#if defined(REINTERPRET_INPUT_AS_3D)
1308 ,
1309 uint lhs_cross_plane_pad
1310#endif // REINTERPRET_INPUT_AS_3D
1311#if defined(REINTERPRET_OUTPUT_AS_3D)
1312 ,
1313 uint dst_cross_plane_pad
1314#endif // REINTERPRET_OUTPUT_AS_3D
1315 )
1316{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001317 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001318#define RHS_BLOCK_SIZE ((K0) * (N0))
1319
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001320 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001321#if defined(RHS_INTERLEAVE)
1322#define RHS_OFFSET_X (K0)
1323#define RHS_STEP_X ((K0) * (H0))
1324#define RHS_STEP_LOOP (1)
1325#else // defined(RHS_INTERLEAVE)
1326#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1327#define RHS_STEP_X (K0)
1328#define RHS_STEP_LOOP (H0)
1329#endif // defined(RHS_INTERLEAVE)
1330
1331 uint x = get_global_id(0);
1332 uint y = get_global_id(1);
1333 uint z = get_global_id(2);
1334
1335 // Compute LHS matrix address
1336 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1337
1338 // Compute RHS matrix address
1339 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1340
1341#if defined(MATRIX_B_DEPTH)
1342 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1343 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1344#else // defined(MATRIX_B_DEPTH)
1345 rhs_offset += z * rhs_stride_z;
1346#endif // defined(MATRIX_B_DEPTH)
1347
1348 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1349
1350#if defined(REINTERPRET_INPUT_AS_3D)
1351 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1352 // in order to take into account the presence of possible cross plane paddings
1353 //
1354 // | |
1355 // | plane0 |
1356 // | |
1357 // |__________________|
1358 // |******************|
1359 // | cross_plane_pad |
1360 // |******************|
1361 // | |
1362 // | plane1 |
1363 // | |
1364 // |__________________|
1365
1366 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1367 zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1368 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
1369 zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
1370#if M0 > 1
1371 zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1372 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
1373 zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
1374#endif // M0 > 1
1375#if M0 > 2
1376 zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1377 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
1378 zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
1379#endif // M0 > 2
1380#if M0 > 3
1381 zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1382 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
1383 zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
1384#endif // M0 > 3
1385#if M0 > 4
1386 zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1387 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
1388 zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
1389#endif // M0 > 4
1390#if M0 > 5
1391 zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1392 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
1393 zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
1394#endif // M0 > 5
1395#if M0 > 6
1396 zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1397 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
1398 zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
1399#endif // M0 > 6
1400#if M0 > 7
1401 zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1402 zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
1403 zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
1404#endif // M0 > 7
1405
1406 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1407 // multiply lhs_stride_z by DEPTH_GEMM3D
1408 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1409
1410#else // defined(REINTERPRET_INPUT_AS_3D)
1411
1412 // Add offset for batched GEMM
1413 lhs_offset += z * lhs_stride_z;
1414
1415#endif // defined(REINTERPRET_INPUT_AS_3D)
1416
1417 // Initialize the accumulators
1418 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1419
1420 int i = 0;
1421 for(; i <= (K - K0); i += K0)
1422 {
1423 // Supported cases (M0, K0):
1424 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1425 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1426 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1427 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1428 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1429 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1430 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1431 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1432 // Load values from LHS matrix
1433 VEC_DATA_TYPE(DATA_TYPE, K0)
1434 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1435#if M0 > 1
1436 VEC_DATA_TYPE(DATA_TYPE, K0)
1437 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1438#endif // M0 > 1
1439#if M0 > 2
1440 VEC_DATA_TYPE(DATA_TYPE, K0)
1441 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1442#endif // M0 > 2
1443#if M0 > 3
1444 VEC_DATA_TYPE(DATA_TYPE, K0)
1445 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1446#endif // M0 > 3
1447#if M0 > 4
1448 VEC_DATA_TYPE(DATA_TYPE, K0)
1449 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1450#endif // M0 > 4
1451#if M0 > 5
1452 VEC_DATA_TYPE(DATA_TYPE, K0)
1453 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1454#endif // M0 > 5
1455#if M0 > 6
1456 VEC_DATA_TYPE(DATA_TYPE, K0)
1457 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1458#endif // M0 > 6
1459#if M0 > 7
1460 VEC_DATA_TYPE(DATA_TYPE, K0)
1461 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1462#endif // M0 > 7
1463
1464 // Load values from RHS matrix
1465 VEC_DATA_TYPE(DATA_TYPE, K0)
1466 b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1467 VEC_DATA_TYPE(DATA_TYPE, K0)
1468 b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
1469#if N0 > 2
1470 VEC_DATA_TYPE(DATA_TYPE, K0)
1471 b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
1472#endif // N0 > 2
1473#if N0 > 3
1474 VEC_DATA_TYPE(DATA_TYPE, K0)
1475 b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
1476#endif // N0 > 3
1477#if N0 > 4
1478 VEC_DATA_TYPE(DATA_TYPE, K0)
1479 b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
1480 VEC_DATA_TYPE(DATA_TYPE, K0)
1481 b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
1482 VEC_DATA_TYPE(DATA_TYPE, K0)
1483 b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
1484 VEC_DATA_TYPE(DATA_TYPE, K0)
1485 b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
1486#endif // N0 > 4
1487#if N0 > 8
1488 VEC_DATA_TYPE(DATA_TYPE, K0)
1489 b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
1490 VEC_DATA_TYPE(DATA_TYPE, K0)
1491 b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
1492 VEC_DATA_TYPE(DATA_TYPE, K0)
1493 bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
1494 VEC_DATA_TYPE(DATA_TYPE, K0)
1495 bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
1496 VEC_DATA_TYPE(DATA_TYPE, K0)
1497 bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
1498 VEC_DATA_TYPE(DATA_TYPE, K0)
1499 bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
1500 VEC_DATA_TYPE(DATA_TYPE, K0)
1501 bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
1502 VEC_DATA_TYPE(DATA_TYPE, K0)
1503 bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
1504#endif // N0 > 8
1505
1506 // Accumulate
1507 ARM_DOT_K0XN0(K0, a0, b, c0);
1508#if M0 > 1
1509 ARM_DOT_K0XN0(K0, a1, b, c1);
1510#endif // M0 > 1
1511#if M0 > 2
1512 ARM_DOT_K0XN0(K0, a2, b, c2);
1513#endif // M0 > 2
1514#if M0 > 3
1515 ARM_DOT_K0XN0(K0, a3, b, c3);
1516#endif // M0 > 3
1517#if M0 > 4
1518 ARM_DOT_K0XN0(K0, a4, b, c4);
1519#endif // M0 > 4
1520#if M0 > 5
1521 ARM_DOT_K0XN0(K0, a5, b, c5);
1522#endif // M0 > 5
1523#if M0 > 6
1524 ARM_DOT_K0XN0(K0, a6, b, c6);
1525#endif // M0 > 6
1526#if M0 > 7
1527 ARM_DOT_K0XN0(K0, a7, b, c7);
1528#endif // M0 > 7
1529
1530 lhs_offset += K0 * sizeof(DATA_TYPE);
1531 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1532 }
1533
1534 // Left-over accumulations
1535 for(; i < K; ++i)
1536 {
1537 // Supported cases (M0, K0):
1538 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1539 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1540 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1541 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1542 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1543 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1544 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1545 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1546 // Load values from LHS matrix
1547 DATA_TYPE a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1548#if M0 > 1
1549 DATA_TYPE a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1550#endif // M0 > 1
1551#if M0 > 2
1552 DATA_TYPE a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1553#endif // M0 > 2
1554#if M0 > 3
1555 DATA_TYPE a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1556#endif // M0 > 3
1557#if M0 > 4
1558 DATA_TYPE a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1559#endif // M0 > 4
1560#if M0 > 5
1561 DATA_TYPE a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1562#endif // M0 > 5
1563#if M0 > 6
1564 DATA_TYPE a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1565#endif // M0 > 6
1566#if M0 > 7
1567 DATA_TYPE a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
1568#endif // M0 > 7
1569
1570 // Load values from RHS matrix
1571 DATA_TYPE b0 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
1572 DATA_TYPE b1 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
1573#if N0 > 2
1574 DATA_TYPE b2 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
1575#endif // N0 > 2
1576#if N0 > 3
1577 DATA_TYPE b3 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
1578#endif // N0 > 3
1579#if N0 > 4
1580 DATA_TYPE b4 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
1581 DATA_TYPE b5 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
1582 DATA_TYPE b6 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
1583 DATA_TYPE b7 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
1584#endif // N0 > 4
1585#if N0 > 8
1586 DATA_TYPE b8 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
1587 DATA_TYPE b9 = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
1588 DATA_TYPE bA = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
1589 DATA_TYPE bB = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
1590 DATA_TYPE bC = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
1591 DATA_TYPE bD = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
1592 DATA_TYPE bE = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
1593 DATA_TYPE bF = *((__global DATA_TYPE *)(rhs_ptr + rhs_offset + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
1594#endif // N0 > 8
1595
1596 // Accumulate
1597 ARM_DOT_K0XN0(1, a0, b, c0);
1598#if M0 > 1
1599 ARM_DOT_K0XN0(1, a1, b, c1);
1600#endif // M0 > 1
1601#if M0 > 2
1602 ARM_DOT_K0XN0(1, a2, b, c2);
1603#endif // M0 > 2
1604#if M0 > 3
1605 ARM_DOT_K0XN0(1, a3, b, c3);
1606#endif // M0 > 3
1607#if M0 > 4
1608 ARM_DOT_K0XN0(1, a4, b, c4);
1609#endif // M0 > 4
1610#if M0 > 5
1611 ARM_DOT_K0XN0(1, a5, b, c5);
1612#endif // M0 > 5
1613#if M0 > 6
1614 ARM_DOT_K0XN0(1, a6, b, c6);
1615#endif // M0 > 6
1616#if M0 > 7
1617 ARM_DOT_K0XN0(1, a7, b, c7);
1618#endif // M0 > 7
1619
1620 lhs_offset += sizeof(DATA_TYPE);
1621 rhs_offset += sizeof(DATA_TYPE);
1622 }
1623
1624 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1625
1626 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1627
1628#if defined(REINTERPRET_OUTPUT_AS_3D)
1629 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1630 // in order to take into account the presence of possible cross plane paddings
1631 //
1632 // | |
1633 // | plane0 |
1634 // | |
1635 // |__________________|
1636 // |******************|
1637 // | cross_plane_pad |
1638 // |******************|
1639 // | |
1640 // | plane1 |
1641 // | |
1642 // |__________________|
1643
1644 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1645 zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1646 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
1647 zout0 *= (dst_cross_plane_pad * dst_stride_y);
1648#if M0 > 1
1649 zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1650 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
1651 zout1 *= (dst_cross_plane_pad * dst_stride_y);
1652#endif // M0 > 1
1653#if M0 > 2
1654 zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1655 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
1656 zout2 *= (dst_cross_plane_pad * dst_stride_y);
1657#endif // M0 > 2
1658#if M0 > 3
1659 zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1660 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
1661 zout3 *= (dst_cross_plane_pad * dst_stride_y);
1662#endif // M0 > 3
1663#if M0 > 4
1664 zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1665 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
1666 zout4 *= (dst_cross_plane_pad * dst_stride_y);
1667#endif // M0 > 4
1668#if M0 > 5
1669 zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1670 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
1671 zout5 *= (dst_cross_plane_pad * dst_stride_y);
1672#endif // M0 > 5
1673#if M0 > 6
1674 zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1675 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
1676 zout6 *= (dst_cross_plane_pad * dst_stride_y);
1677#endif // M0 > 6
1678#if M0 > 7
1679 zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1680 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
1681 zout7 *= (dst_cross_plane_pad * dst_stride_y);
1682#endif // M0 > 7
1683
1684 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1685 // multiply dst_stride_z by DEPTH_GEMM3D
1686 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1687
1688#else // defined(REINTERPRET_OUTPUT_AS_3D)
1689
1690 // Add offset for batched GEMM
1691 dst_addr += z * dst_stride_z;
1692
1693#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1694
1695 // Multiply by the weight of matrix-matrix product and store the result
1696#if defined(ALPHA)
1697 c0 = c0 * (DATA_TYPE)ALPHA;
1698#if M0 > 1
1699 c1 = c1 * (DATA_TYPE)ALPHA;
1700#endif // M0 > 1
1701#if M0 > 2
1702 c2 = c2 * (DATA_TYPE)ALPHA;
1703#endif // M0 > 2
1704#if M0 > 3
1705 c3 = c3 * (DATA_TYPE)ALPHA;
1706#endif // M0 > 3
1707#if M0 > 4
1708 c4 = c4 * (DATA_TYPE)ALPHA;
1709#endif // M0 > 4
1710#if M0 > 5
1711 c5 = c5 * (DATA_TYPE)ALPHA;
1712#endif // M0 > 5
1713#if M0 > 6
1714 c6 = c6 * (DATA_TYPE)ALPHA;
1715#endif // M0 > 5
1716#if M0 > 7
1717 c7 = c7 * (DATA_TYPE)ALPHA;
1718#endif // M0 > 7
1719#endif // defined(ALPHA)
1720
1721 // Store output block
1722 VSTORE(N0)
1723 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
1724#if M0 > 1
1725 VSTORE(N0)
1726 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
1727#endif // M0 > 1
1728#if M0 > 2
1729 VSTORE(N0)
1730 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
1731#endif // M0 > 2
1732#if M0 > 3
1733 VSTORE(N0)
1734 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
1735#endif // M0 > 3
1736#if M0 > 4
1737 VSTORE(N0)
1738 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
1739#endif // M0 > 4
1740#if M0 > 5
1741 VSTORE(N0)
1742 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
1743#endif // M0 > 5
1744#if M0 > 6
1745 VSTORE(N0)
1746 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
1747#endif // M0 > 6
1748#if M0 > 7
1749 VSTORE(N0)
1750 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
1751#endif // M0 > 7
1752
1753#undef RHS_BLOCK_SIZE
1754#undef RHS_OFFSET_X
1755#undef RHS_STEP_X
1756}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001757
// VFMA(x, y, acc): in-place fused multiply-add, acc = fma(x, y, acc).
// Written as a statement expression, so the expansion also yields the updated accumulator value.
#define VFMA(x, y, acc)                \
    ({                                 \
        (acc) = fma((x), (y), (acc));  \
    })
1762
1763#if M0 == 1
1764#define LD_RHS_VFMA_M0xN0(i, a, c) \
1765 ({ \
1766 VEC_DATA_TYPE(DATA_TYPE, N0) \
1767 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1768 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1769 })
1770#elif M0 == 2 // M0 == 2
1771#define LD_RHS_VFMA_M0xN0(i, a, c) \
1772 ({ \
1773 VEC_DATA_TYPE(DATA_TYPE, N0) \
1774 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1775 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1776 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1777 })
1778#elif M0 == 3 // M0 == 3
1779#define LD_RHS_VFMA_M0xN0(i, a, c) \
1780 ({ \
1781 VEC_DATA_TYPE(DATA_TYPE, N0) \
1782 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1783 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1784 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1785 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1786 })
1787#elif M0 == 4 // M0 == 4
1788#define LD_RHS_VFMA_M0xN0(i, a, c) \
1789 ({ \
1790 VEC_DATA_TYPE(DATA_TYPE, N0) \
1791 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1792 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1793 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1794 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1795 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1796 })
1797#elif M0 == 5 // M0 == 5
1798#define LD_RHS_VFMA_M0xN0(i, a, c) \
1799 ({ \
1800 VEC_DATA_TYPE(DATA_TYPE, N0) \
1801 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1802 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1803 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1804 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1805 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1806 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1807 })
1808#elif M0 == 6 // M0 == 6
1809#define LD_RHS_VFMA_M0xN0(i, a, c) \
1810 ({ \
1811 VEC_DATA_TYPE(DATA_TYPE, N0) \
1812 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1813 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1814 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1815 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1816 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1817 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1818 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1819 })
1820#elif M0 == 7 // M0 == 7
1821#define LD_RHS_VFMA_M0xN0(i, a, c) \
1822 ({ \
1823 VEC_DATA_TYPE(DATA_TYPE, N0) \
1824 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1825 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1826 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1827 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1828 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1829 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1830 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1831 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1832 })
1833#elif M0 == 8 // M0 == 8
1834#define LD_RHS_VFMA_M0xN0(i, a, c) \
1835 ({ \
1836 VEC_DATA_TYPE(DATA_TYPE, N0) \
1837 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1838 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1839 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1840 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1841 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1842 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1843 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1844 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1845 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1846 })
1847#else // M0 not supported
1848#error "M0 not supported"
1849#endif // M0 not supported
1850
1851/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1852 * The LHS matrix is NOT reshaped
1853 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1854 *
1855 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1856 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1857 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1858 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1859 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1860 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1861 * - N0 = 2, 3, 4, 8, 16
1862 * - K0 = 2, 3, 4, 8, 16
1863 * - H0 > 1
1864 *
1865 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1866 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1867 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1868 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1869 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1870 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1871 *
1872 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1873 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
1874 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem (in bytes)
1875 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
1876 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem (in bytes)
1877 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
1878 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1879 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1880 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem (in bytes)
1881 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1882 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem (in bytes)
1883 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1884 * @param[out] dst_ptr Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
1885 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1886 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1887 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1888 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1889 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1890 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1891 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1892 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1893 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1894 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1895 */
1896__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1897 IMAGE_DECLARATION(rhs),
1898 IMAGE_DECLARATION(dst),
1899 uint lhs_stride_z,
1900 uint rhs_stride_z,
1901 uint dst_stride_z
1902#if defined(REINTERPRET_INPUT_AS_3D)
1903 ,
1904 uint lhs_cross_plane_pad
1905#endif // REINTERPRET_INPUT_AS_3D
1906#if defined(REINTERPRET_OUTPUT_AS_3D)
1907 ,
1908 uint dst_cross_plane_pad
1909#endif // REINTERPRET_OUTPUT_AS_3D
1910 )
1911{
1912 // Block size
1913#define RHS_BLOCK_SIZE ((K0) * (N0))
1914
1915 // RHS offset and step X
1916#if defined(RHS_INTERLEAVE)
1917#define RHS_OFFSET_X (N0)
1918#define RHS_STEP_X ((N0) * (H0))
1919#define RHS_STEP_LOOP (1)
1920#else // defined(RHS_INTERLEAVE)
1921#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1922#define RHS_STEP_X (N0)
1923#define RHS_STEP_LOOP (H0)
1924#endif // defined(RHS_INTERLEAVE)
1925
1926 uint x = get_global_id(0);
1927 uint y = get_global_id(1);
1928 uint z = get_global_id(2);
1929
1930 // Compute LHS matrix address
1931 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1932
1933 // Compute RHS matrix address
1934 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1935
1936#if defined(MATRIX_B_DEPTH)
1937 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1938 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1939#else // defined(MATRIX_B_DEPTH)
1940 rhs_offset += z * rhs_stride_z;
1941#endif // defined(MATRIX_B_DEPTH)
1942
1943 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1944
1945#if defined(REINTERPRET_INPUT_AS_3D)
1946 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1947 // in order to take into account the presence of possible cross plane paddings
1948 //
1949 // | |
1950 // | plane0 |
1951 // | |
1952 // |__________________|
1953 // |******************|
1954 // | cross_plane_pad |
1955 // |******************|
1956 // | |
1957 // | plane1 |
1958 // | |
1959 // |__________________|
1960
1961 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1962 zin0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1963 zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
1964 zin0 *= (lhs_cross_plane_pad * lhs_stride_y);
1965#if M0 > 1
1966 zin1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1967 zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
1968 zin1 *= (lhs_cross_plane_pad * lhs_stride_y);
1969#endif // M0 > 1
1970#if M0 > 2
1971 zin2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1972 zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
1973 zin2 *= (lhs_cross_plane_pad * lhs_stride_y);
1974#endif // M0 > 2
1975#if M0 > 3
1976 zin3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1977 zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
1978 zin3 *= (lhs_cross_plane_pad * lhs_stride_y);
1979#endif // M0 > 3
1980#if M0 > 4
1981 zin4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1982 zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
1983 zin4 *= (lhs_cross_plane_pad * lhs_stride_y);
1984#endif // M0 > 4
1985#if M0 > 5
1986 zin5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1987 zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
1988 zin5 *= (lhs_cross_plane_pad * lhs_stride_y);
1989#endif // M0 > 5
1990#if M0 > 6
1991 zin6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1992 zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
1993 zin6 *= (lhs_cross_plane_pad * lhs_stride_y);
1994#endif // M0 > 6
1995#if M0 > 7
1996 zin7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
1997 zin7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
1998 zin7 *= (lhs_cross_plane_pad * lhs_stride_y);
1999#endif // M0 > 7
2000
2001 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2002 // multiply lhs_stride_z by DEPTH_GEMM3D
2003 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2004
2005#else // defined(REINTERPRET_INPUT_AS_3D)
2006
2007 // Add offset for batched GEMM
2008 lhs_offset += z * lhs_stride_z;
2009
2010#endif // defined(REINTERPRET_INPUT_AS_3D)
2011
2012 // Initialize the accumulators
2013 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
2014
2015 int i = 0;
2016 for(; i <= (K - K0); i += K0)
2017 {
2018 // Supported cases (M0, K0):
2019 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2020 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2021 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2022 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2023 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2024 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2025 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2026 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2027 // Load values from LHS matrix
2028 VEC_DATA_TYPE(DATA_TYPE, K0)
2029 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
2030#if M0 > 1
2031 VEC_DATA_TYPE(DATA_TYPE, K0)
2032 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
2033#endif // M0 > 1
2034#if M0 > 2
2035 VEC_DATA_TYPE(DATA_TYPE, K0)
2036 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
2037#endif // M0 > 2
2038#if M0 > 3
2039 VEC_DATA_TYPE(DATA_TYPE, K0)
2040 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
2041#endif // M0 > 3
2042#if M0 > 4
2043 VEC_DATA_TYPE(DATA_TYPE, K0)
2044 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
2045#endif // M0 > 4
2046#if M0 > 5
2047 VEC_DATA_TYPE(DATA_TYPE, K0)
2048 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
2049#endif // M0 > 5
2050#if M0 > 6
2051 VEC_DATA_TYPE(DATA_TYPE, K0)
2052 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
2053#endif // M0 > 6
2054#if M0 > 7
2055 VEC_DATA_TYPE(DATA_TYPE, K0)
2056 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
2057#endif // M0 > 7
2058
2059 LD_RHS_VFMA_M0xN0(0, a, c);
2060 LD_RHS_VFMA_M0xN0(1, a, c);
2061#if K0 > 2
2062 LD_RHS_VFMA_M0xN0(2, a, c);
2063#endif // K0 > 2
2064#if K0 > 3
2065 LD_RHS_VFMA_M0xN0(3, a, c);
2066#endif // K0 > 3
2067#if K0 > 4
2068 LD_RHS_VFMA_M0xN0(4, a, c);
2069 LD_RHS_VFMA_M0xN0(5, a, c);
2070 LD_RHS_VFMA_M0xN0(6, a, c);
2071 LD_RHS_VFMA_M0xN0(7, a, c);
2072#endif // K0 > 4
2073#if K0 > 8
2074 LD_RHS_VFMA_M0xN0(8, a, c);
2075 LD_RHS_VFMA_M0xN0(9, a, c);
2076 LD_RHS_VFMA_M0xN0(A, a, c);
2077 LD_RHS_VFMA_M0xN0(B, a, c);
2078 LD_RHS_VFMA_M0xN0(C, a, c);
2079 LD_RHS_VFMA_M0xN0(D, a, c);
2080 LD_RHS_VFMA_M0xN0(E, a, c);
2081 LD_RHS_VFMA_M0xN0(F, a, c);
2082#endif // K0 > 8
2083
2084 lhs_offset += K0 * sizeof(DATA_TYPE);
2085 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
2086 }
2087
2088 // Left-over accumulations
2089 for(; i < K; ++i)
2090 {
2091 // Load values from LHS matrix
2092 VEC_DATA_TYPE(DATA_TYPE, 2)
2093 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
2094#if M0 > 1
2095 VEC_DATA_TYPE(DATA_TYPE, 2)
2096 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
2097#endif // M0 > 1
2098#if M0 > 2
2099 VEC_DATA_TYPE(DATA_TYPE, 2)
2100 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
2101#endif // M0 > 2
2102#if M0 > 3
2103 VEC_DATA_TYPE(DATA_TYPE, 2)
2104 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
2105#endif // M0 > 3
2106#if M0 > 4
2107 VEC_DATA_TYPE(DATA_TYPE, 2)
2108 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
2109#endif // M0 > 4
2110#if M0 > 5
2111 VEC_DATA_TYPE(DATA_TYPE, 2)
2112 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
2113#endif // M0 > 5
2114#if M0 > 6
2115 VEC_DATA_TYPE(DATA_TYPE, 2)
2116 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
2117#endif // M0 > 6
2118#if M0 > 7
2119 VEC_DATA_TYPE(DATA_TYPE, 2)
2120 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin));
2121#endif // M0 > 7
2122
2123 LD_RHS_VFMA_M0xN0(0, a, c);
2124
2125 lhs_offset += sizeof(DATA_TYPE);
2126 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
2127 }
2128
2129 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2130
2131 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2132
2133#if defined(REINTERPRET_OUTPUT_AS_3D)
2134 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2135 // in order to take into account the presence of possible cross plane paddings
2136 //
2137 // | |
2138 // | plane0 |
2139 // | |
2140 // |__________________|
2141 // |******************|
2142 // | cross_plane_pad |
2143 // |******************|
2144 // | |
2145 // | plane1 |
2146 // | |
2147 // |__________________|
2148
2149 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2150 zout0 = (0 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2151 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
2152 zout0 *= (dst_cross_plane_pad * dst_stride_y);
2153#if M0 > 1
2154 zout1 = (1 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2155 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
2156 zout1 *= (dst_cross_plane_pad * dst_stride_y);
2157#endif // M0 > 1
2158#if M0 > 2
2159 zout2 = (2 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2160 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
2161 zout2 *= (dst_cross_plane_pad * dst_stride_y);
2162#endif // M0 > 2
2163#if M0 > 3
2164 zout3 = (3 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2165 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
2166 zout3 *= (dst_cross_plane_pad * dst_stride_y);
2167#endif // M0 > 3
2168#if M0 > 4
2169 zout4 = (4 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2170 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
2171 zout4 *= (dst_cross_plane_pad * dst_stride_y);
2172#endif // M0 > 4
2173#if M0 > 5
2174 zout5 = (5 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2175 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
2176 zout5 *= (dst_cross_plane_pad * dst_stride_y);
2177#endif // M0 > 5
2178#if M0 > 6
2179 zout6 = (6 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2180 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
2181 zout6 *= (dst_cross_plane_pad * dst_stride_y);
2182#endif // M0 > 6
2183#if M0 > 7
2184 zout7 = (7 + (uint)(y * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2185 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
2186 zout7 *= (dst_cross_plane_pad * dst_stride_y);
2187#endif // M0 > 7
2188
2189 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2190 // multiply dst_stride_z by DEPTH_GEMM3D
2191 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2192
2193#else // defined(REINTERPRET_OUTPUT_AS_3D)
2194
2195 // Add offset for batched GEMM
2196 dst_addr += z * dst_stride_z;
2197
2198#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2199
2200 // Multiply by the weight of matrix-matrix product and store the result
2201#if defined(ALPHA)
2202 c0 = c0 * (DATA_TYPE)ALPHA;
2203#if M0 > 1
2204 c1 = c1 * (DATA_TYPE)ALPHA;
2205#endif // M0 > 1
2206#if M0 > 2
2207 c2 = c2 * (DATA_TYPE)ALPHA;
2208#endif // M0 > 2
2209#if M0 > 3
2210 c3 = c3 * (DATA_TYPE)ALPHA;
2211#endif // M0 > 3
2212#if M0 > 4
2213 c4 = c4 * (DATA_TYPE)ALPHA;
2214#endif // M0 > 4
2215#if M0 > 5
2216 c5 = c5 * (DATA_TYPE)ALPHA;
2217#endif // M0 > 5
2218#if M0 > 6
2219 c6 = c6 * (DATA_TYPE)ALPHA;
2220#endif // M0 > 5
2221#if M0 > 7
2222 c7 = c7 * (DATA_TYPE)ALPHA;
2223#endif // M0 > 7
2224#endif // defined(ALPHA)
2225
2226 // Store output block
2227 VSTORE(N0)
2228 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
2229#if M0 > 1
2230 VSTORE(N0)
2231 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
2232#endif // M0 > 1
2233#if M0 > 2
2234 VSTORE(N0)
2235 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
2236#endif // M0 > 2
2237#if M0 > 3
2238 VSTORE(N0)
2239 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
2240#endif // M0 > 3
2241#if M0 > 4
2242 VSTORE(N0)
2243 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
2244#endif // M0 > 4
2245#if M0 > 5
2246 VSTORE(N0)
2247 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
2248#endif // M0 > 5
2249#if M0 > 6
2250 VSTORE(N0)
2251 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
2252#endif // M0 > 6
2253#if M0 > 7
2254 VSTORE(N0)
2255 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
2256#endif // M0 > 7
2257
2258#undef RHS_BLOCK_SIZE
2259#undef RHS_OFFSET_X
2260#undef RHS_STEP_X
2261}
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002262#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(K)
2263
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002264#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002265
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002266#if K0 == 2
2267#define ARM_DOT_K0(a, b, c) \
2268 ({ \
2269 c = fma(a.s0, b.s0, c); \
2270 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002271 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002272#elif K0 == 3 // K0 == 3
2273#define ARM_DOT_K0(a, b, c) \
2274 ({ \
2275 c = fma(a.s0, b.s0, c); \
2276 c = fma(a.s1, b.s1, c); \
2277 c = fma(a.s2, b.s2, c); \
2278 })
2279#elif K0 == 4 // K0 == 4
2280#define ARM_DOT_K0(a, b, c) \
2281 ({ \
2282 c = fma(a.s0, b.s0, c); \
2283 c = fma(a.s1, b.s1, c); \
2284 c = fma(a.s2, b.s2, c); \
2285 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002286 })
2287#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002288#define ARM_DOT_K0(a, b, c) \
2289 ({ \
2290 c = fma(a.s0, b.s0, c); \
2291 c = fma(a.s1, b.s1, c); \
2292 c = fma(a.s2, b.s2, c); \
2293 c = fma(a.s3, b.s3, c); \
2294 c = fma(a.s4, b.s4, c); \
2295 c = fma(a.s5, b.s5, c); \
2296 c = fma(a.s6, b.s6, c); \
2297 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002298 })
2299#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002300#define ARM_DOT_K0(a, b, c) \
2301 ({ \
2302 c = fma(a.s0, b.s0, c); \
2303 c = fma(a.s1, b.s1, c); \
2304 c = fma(a.s2, b.s2, c); \
2305 c = fma(a.s3, b.s3, c); \
2306 c = fma(a.s4, b.s4, c); \
2307 c = fma(a.s5, b.s5, c); \
2308 c = fma(a.s6, b.s6, c); \
2309 c = fma(a.s7, b.s7, c); \
2310 c = fma(a.s8, b.s8, c); \
2311 c = fma(a.s9, b.s9, c); \
2312 c = fma(a.sA, b.sA, c); \
2313 c = fma(a.sB, b.sB, c); \
2314 c = fma(a.sC, b.sC, c); \
2315 c = fma(a.sD, b.sD, c); \
2316 c = fma(a.sE, b.sE, c); \
2317 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002318 })
2319#else // K0 not supported
2320#error "K0 value not supported"
2321#endif // K0 conditions
2322
2323#if N0 == 2
2324#define ARM_DOT_K0XN0(a, b, c) \
2325 ({ \
2326 ARM_DOT_K0((a), (b##0), (c.s0)); \
2327 ARM_DOT_K0((a), (b##1), (c.s1)); \
2328 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002329#elif N0 == 3 // N0 == 3
2330#define ARM_DOT_K0XN0(a, b, c) \
2331 ({ \
2332 ARM_DOT_K0((a), (b##0), (c.s0)); \
2333 ARM_DOT_K0((a), (b##1), (c.s1)); \
2334 ARM_DOT_K0((a), (b##2), (c.s2)); \
2335 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002336#elif N0 == 4 // N0 == 4
2337#define ARM_DOT_K0XN0(a, b, c) \
2338 ({ \
2339 ARM_DOT_K0((a), (b##0), (c.s0)); \
2340 ARM_DOT_K0((a), (b##1), (c.s1)); \
2341 ARM_DOT_K0((a), (b##2), (c.s2)); \
2342 ARM_DOT_K0((a), (b##3), (c.s3)); \
2343 })
2344#elif N0 == 8 // N0 == 8
2345#define ARM_DOT_K0XN0(a, b, c) \
2346 ({ \
2347 ARM_DOT_K0((a), (b##0), (c.s0)); \
2348 ARM_DOT_K0((a), (b##1), (c.s1)); \
2349 ARM_DOT_K0((a), (b##2), (c.s2)); \
2350 ARM_DOT_K0((a), (b##3), (c.s3)); \
2351 ARM_DOT_K0((a), (b##4), (c.s4)); \
2352 ARM_DOT_K0((a), (b##5), (c.s5)); \
2353 ARM_DOT_K0((a), (b##6), (c.s6)); \
2354 ARM_DOT_K0((a), (b##7), (c.s7)); \
2355 })
2356#elif N0 == 16 // N0 == 16
2357#define ARM_DOT_K0XN0(a, b, c) \
2358 ({ \
2359 ARM_DOT_K0((a), (b##0), (c.s0)); \
2360 ARM_DOT_K0((a), (b##1), (c.s1)); \
2361 ARM_DOT_K0((a), (b##2), (c.s2)); \
2362 ARM_DOT_K0((a), (b##3), (c.s3)); \
2363 ARM_DOT_K0((a), (b##4), (c.s4)); \
2364 ARM_DOT_K0((a), (b##5), (c.s5)); \
2365 ARM_DOT_K0((a), (b##6), (c.s6)); \
2366 ARM_DOT_K0((a), (b##7), (c.s7)); \
2367 ARM_DOT_K0((a), (b##8), (c.s8)); \
2368 ARM_DOT_K0((a), (b##9), (c.s9)); \
2369 ARM_DOT_K0((a), (b##A), (c.sA)); \
2370 ARM_DOT_K0((a), (b##B), (c.sB)); \
2371 ARM_DOT_K0((a), (b##C), (c.sC)); \
2372 ARM_DOT_K0((a), (b##D), (c.sD)); \
2373 ARM_DOT_K0((a), (b##E), (c.sE)); \
2374 ARM_DOT_K0((a), (b##F), (c.sF)); \
2375 })
2376#else // N0 not supported
2377#error "N0 value not supported"
2378#endif // N0 conditions
2379
2380/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2381 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
2382 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
2383 *
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002384 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
2385 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
2386 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
2387 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
2388 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2389 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002390 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002391 * - N0 = 2, 3, 4, 8, 16
2392 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002393 *
2394 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2395 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2396 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2397 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2398 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2399 *
2400 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2401 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2402 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2403 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2404 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2405 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002406 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002407 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2408 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2409 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2410 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2411 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002412 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002413 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2414 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2415 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2416 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2417 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002418 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002419 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2420 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2421 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2422 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2423 */
2424__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
2425 IMAGE_DECLARATION(rhs),
2426 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002427 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002428 uint lhs_stride_z,
2429 uint rhs_stride_z,
2430 uint dst_stride_z
2431#if defined(REINTERPRET_OUTPUT_AS_3D)
2432 ,
2433 uint dst_cross_plane_pad
2434#endif // REINTERPRET_OUTPUT_AS_3D
2435 )
2436{
2437 // Block size
2438#define LHS_BLOCK_SIZE ((K0) * (M0))
2439
2440#if defined(LHS_INTERLEAVE)
2441#define LHS_OFFSET_X (K0)
2442#define LHS_STEP_X ((K0) * (V0))
2443#define LHS_STEP_LOOP (1)
2444#else // defined(INTERLEAVE)
2445#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2446#define LHS_STEP_X (K0)
2447#define LHS_STEP_LOOP (V0)
2448#endif // defined(INTERLEAVE)
2449
2450 // Block size
2451#define RHS_BLOCK_SIZE ((K0) * (N0))
2452
2453 // RHS offset and step X
2454#if defined(RHS_INTERLEAVE)
2455#define RHS_OFFSET_X (K0)
2456#define RHS_STEP_X ((K0) * (H0))
2457#define RHS_STEP_LOOP (1)
2458#else // defined(RHS_INTERLEAVE)
2459#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2460#define RHS_STEP_X (K0)
2461#define RHS_STEP_LOOP (H0)
2462#endif // defined(RHS_INTERLEAVE)
2463
2464 // Compute LHS matrix address
2465 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2466 (get_global_id(2) * lhs_stride_z);
2467
2468 // Compute RHS matrix address
2469 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
2470
2471#if defined(MATRIX_B_DEPTH)
2472 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2473 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
2474#else // defined(MATRIX_B_DEPTH)
2475 rhs_addr += get_global_id(2) * rhs_stride_z;
2476#endif // defined(MATRIX_B_DEPTH)
2477
2478 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002479 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002480
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002481 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002482 {
2483 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00002484 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2485 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2486 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2487 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2488 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2489 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2490 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2491 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002492 // Load values from LHS matrix
2493 VEC_DATA_TYPE(DATA_TYPE, K0)
2494 a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 0 * LHS_STEP_X * sizeof(DATA_TYPE)));
2495#if M0 > 1
2496 VEC_DATA_TYPE(DATA_TYPE, K0)
2497 a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 1 * LHS_STEP_X * sizeof(DATA_TYPE)));
2498#endif // M0 > 1
2499#if M0 > 2
2500 VEC_DATA_TYPE(DATA_TYPE, K0)
2501 a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 2 * LHS_STEP_X * sizeof(DATA_TYPE)));
2502#endif // M0 > 2
2503#if M0 > 3
2504 VEC_DATA_TYPE(DATA_TYPE, K0)
2505 a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 3 * LHS_STEP_X * sizeof(DATA_TYPE)));
2506#endif // M0 > 3
2507#if M0 > 4
2508 VEC_DATA_TYPE(DATA_TYPE, K0)
2509 a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 4 * LHS_STEP_X * sizeof(DATA_TYPE)));
2510#endif // M0 > 4
2511#if M0 > 5
2512 VEC_DATA_TYPE(DATA_TYPE, K0)
2513 a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 5 * LHS_STEP_X * sizeof(DATA_TYPE)));
2514#endif // M0 > 5
2515#if M0 > 6
2516 VEC_DATA_TYPE(DATA_TYPE, K0)
2517 a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 6 * LHS_STEP_X * sizeof(DATA_TYPE)));
2518#endif // M0 > 6
2519#if M0 > 7
2520 VEC_DATA_TYPE(DATA_TYPE, K0)
2521 a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(lhs_addr + 7 * LHS_STEP_X * sizeof(DATA_TYPE)));
2522#endif // M0 > 7
2523
2524 // Load values from RHS matrix
2525 VEC_DATA_TYPE(DATA_TYPE, K0)
2526 b0 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 0 * RHS_STEP_X * sizeof(DATA_TYPE)));
2527 VEC_DATA_TYPE(DATA_TYPE, K0)
2528 b1 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 1 * RHS_STEP_X * sizeof(DATA_TYPE)));
2529#if N0 > 2
2530 VEC_DATA_TYPE(DATA_TYPE, K0)
2531 b2 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 2 * RHS_STEP_X * sizeof(DATA_TYPE)));
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002532#endif // N0 > 2
2533#if N0 > 3
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002534 VEC_DATA_TYPE(DATA_TYPE, K0)
2535 b3 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 3 * RHS_STEP_X * sizeof(DATA_TYPE)));
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002536#endif // N0 > 3
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002537#if N0 > 4
2538 VEC_DATA_TYPE(DATA_TYPE, K0)
2539 b4 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 4 * RHS_STEP_X * sizeof(DATA_TYPE)));
2540 VEC_DATA_TYPE(DATA_TYPE, K0)
2541 b5 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 5 * RHS_STEP_X * sizeof(DATA_TYPE)));
2542 VEC_DATA_TYPE(DATA_TYPE, K0)
2543 b6 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 6 * RHS_STEP_X * sizeof(DATA_TYPE)));
2544 VEC_DATA_TYPE(DATA_TYPE, K0)
2545 b7 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 7 * RHS_STEP_X * sizeof(DATA_TYPE)));
2546#endif // N0 > 4
2547#if N0 > 8
2548 VEC_DATA_TYPE(DATA_TYPE, K0)
2549 b8 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 8 * RHS_STEP_X * sizeof(DATA_TYPE)));
2550 VEC_DATA_TYPE(DATA_TYPE, K0)
2551 b9 = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 9 * RHS_STEP_X * sizeof(DATA_TYPE)));
2552 VEC_DATA_TYPE(DATA_TYPE, K0)
2553 bA = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 10 * RHS_STEP_X * sizeof(DATA_TYPE)));
2554 VEC_DATA_TYPE(DATA_TYPE, K0)
2555 bB = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 11 * RHS_STEP_X * sizeof(DATA_TYPE)));
2556 VEC_DATA_TYPE(DATA_TYPE, K0)
2557 bC = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 12 * RHS_STEP_X * sizeof(DATA_TYPE)));
2558 VEC_DATA_TYPE(DATA_TYPE, K0)
2559 bD = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 13 * RHS_STEP_X * sizeof(DATA_TYPE)));
2560 VEC_DATA_TYPE(DATA_TYPE, K0)
2561 bE = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 14 * RHS_STEP_X * sizeof(DATA_TYPE)));
2562 VEC_DATA_TYPE(DATA_TYPE, K0)
2563 bF = VLOAD(K0)(0, (__global DATA_TYPE *)(rhs_addr + 15 * RHS_STEP_X * sizeof(DATA_TYPE)));
2564#endif // N0 > 8
2565
2566 // Accumulate
2567 ARM_DOT_K0XN0(a0, b, c0);
2568#if M0 > 1
2569 ARM_DOT_K0XN0(a1, b, c1);
2570#endif // M0 > 1
2571#if M0 > 2
2572 ARM_DOT_K0XN0(a2, b, c2);
2573#endif // M0 > 2
2574#if M0 > 3
2575 ARM_DOT_K0XN0(a3, b, c3);
2576#endif // M0 > 3
2577#if M0 > 4
2578 ARM_DOT_K0XN0(a4, b, c4);
2579#endif // M0 > 4
2580#if M0 > 5
2581 ARM_DOT_K0XN0(a5, b, c5);
2582#endif // M0 > 5
2583#if M0 > 6
2584 ARM_DOT_K0XN0(a6, b, c6);
2585#endif // M0 > 6
2586#if M0 > 7
2587 ARM_DOT_K0XN0(a7, b, c7);
2588#endif // M0 > 7
2589
2590 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2591 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
2592 }
2593
2594 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2595
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002596 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002597
2598#if defined(REINTERPRET_OUTPUT_AS_3D)
2599 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2600 // in order to take into account the presence of possible cross plane paddings
2601 //
2602 // | |
2603 // | plane0 |
2604 // | |
2605 // |__________________|
2606 // |******************|
2607 // | cross_plane_pad |
2608 // |******************|
2609 // | |
2610 // | plane1 |
2611 // | |
2612 // |__________________|
2613
2614 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2615 zout0 = (0 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2616 zout0 = min((uint)(DEPTH_GEMM3D - 1), zout0);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002617 zout0 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002618#if M0 > 1
2619 zout1 = (1 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2620 zout1 = min((uint)(DEPTH_GEMM3D - 1), zout1);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002621 zout1 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002622#endif // M0 > 1
2623#if M0 > 2
2624 zout2 = (2 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2625 zout2 = min((uint)(DEPTH_GEMM3D - 1), zout2);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002626 zout2 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002627#endif // M0 > 2
2628#if M0 > 3
2629 zout3 = (3 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2630 zout3 = min((uint)(DEPTH_GEMM3D - 1), zout3);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002631 zout3 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002632#endif // M0 > 3
2633#if M0 > 4
2634 zout4 = (4 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2635 zout4 = min((uint)(DEPTH_GEMM3D - 1), zout4);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002636 zout4 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002637#endif // M0 > 4
2638#if M0 > 5
2639 zout5 = (5 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2640 zout5 = min((uint)(DEPTH_GEMM3D - 1), zout5);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002641 zout5 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002642#endif // M0 > 5
2643#if M0 > 6
2644 zout6 = (6 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2645 zout6 = min((uint)(DEPTH_GEMM3D - 1), zout6);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002646 zout6 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002647#endif // M0 > 6
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00002648#if M0 > 7
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002649 zout7 = (7 + (uint)(get_global_id(1) * (uint)M0)) / (uint)HEIGHT_GEMM3D;
2650 zout7 = min((uint)(DEPTH_GEMM3D - 1), zout7);
Gian Marco Iodice49b10152018-12-14 17:13:34 +00002651 zout7 *= (dst_cross_plane_pad * dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002652#endif // M0 > 7
2653
2654 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2655 // multiply dst_stride_z by DEPTH_GEMM3D
2656 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2657
2658#else // defined(REINTERPRET_OUTPUT_AS_3D)
2659
2660 // Add offset for batched GEMM
2661 dst_addr += get_global_id(2) * dst_stride_z;
2662
2663#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2664
2665 // Multiply by the weight of matrix-matrix product and store the result
2666#if defined(ALPHA)
2667 c0 = c0 * (DATA_TYPE)ALPHA;
2668#if M0 > 1
2669 c1 = c1 * (DATA_TYPE)ALPHA;
2670#endif // M0 > 1
2671#if M0 > 2
2672 c2 = c2 * (DATA_TYPE)ALPHA;
2673#endif // M0 > 2
2674#if M0 > 3
2675 c3 = c3 * (DATA_TYPE)ALPHA;
2676#endif // M0 > 3
2677#if M0 > 4
2678 c4 = c4 * (DATA_TYPE)ALPHA;
2679#endif // M0 > 4
2680#if M0 > 5
2681 c5 = c5 * (DATA_TYPE)ALPHA;
2682#endif // M0 > 5
2683#if M0 > 6
2684 c6 = c6 * (DATA_TYPE)ALPHA;
2685#endif // M0 > 5
2686#if M0 > 7
2687 c7 = c7 * (DATA_TYPE)ALPHA;
2688#endif // M0 > 7
2689#endif // defined(ALPHA)
2690
2691 // Store output block
2692 VSTORE(N0)
2693 (c0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout0));
2694#if M0 > 1
2695 VSTORE(N0)
2696 (c1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout1));
2697#endif // M0 > 1
2698#if M0 > 2
2699 VSTORE(N0)
2700 (c2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout2));
2701#endif // M0 > 2
2702#if M0 > 3
2703 VSTORE(N0)
2704 (c3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout3));
2705#endif // M0 > 3
2706#if M0 > 4
2707 VSTORE(N0)
2708 (c4, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_y + zout4));
2709#endif // M0 > 4
2710#if M0 > 5
2711 VSTORE(N0)
2712 (c5, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_y + zout5));
2713#endif // M0 > 5
2714#if M0 > 6
2715 VSTORE(N0)
2716 (c6, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_y + zout6));
2717#endif // M0 > 6
2718#if M0 > 7
2719 VSTORE(N0)
2720 (c7, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_y + zout7));
2721#endif // M0 > 7
2722
2723#undef LHS_BLOCK_SIZE
2724#undef LHS_OFFSET_X
2725#undef LHS_STEP_X
2726#undef RHS_BLOCK_SIZE
2727#undef RHS_OFFSET_X
2728#undef RHS_STEP_X
2729}
2730#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2731
Gian Marco36a0a462018-01-12 10:21:40 +00002732#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
2733
Gian Marco19835e52018-01-30 13:35:54 +00002734#if ELEMENT_SIZE == 1
Gian Marco36a0a462018-01-12 10:21:40 +00002735#define DATA_TYPE uchar
Gian Marco19835e52018-01-30 13:35:54 +00002736#elif ELEMENT_SIZE == 2
2737#define DATA_TYPE ushort
2738#elif ELEMENT_SIZE == 4
2739#define DATA_TYPE uint
2740#else // ELEMENT_SIZE == 1
2741#error "Element size not supported"
2742#endif // ELEMENT_SIZE
Gian Marco36a0a462018-01-12 10:21:40 +00002743
2744/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002745 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    // Work-item coordinates in the 3D NDRange
    const uint gid_x = get_global_id(0);
    const uint gid_y = get_global_id(1);
    const uint gid_z = get_global_id(2);

    // Source (matrix B): per-work-item pointer derived from the kernel arguments
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Destination (matrix B transposed): the roles of X and Y are swapped.
    // - gid_x / MULT_TRANSPOSE1XW_WIDTH selects the destination row;
    // - inside that row, gid_y selects a group of MULT_TRANSPOSE1XW_WIDTH slots and
    //   gid_x % MULT_TRANSPOSE1XW_WIDTH selects the slot within the group;
    // - every slot holds TRANSPOSE_W elements of DATA_TYPE.
    const uint slot_size_in_bytes = TRANSPOSE_W * sizeof(DATA_TYPE);
    const uint slot_index         = gid_y * MULT_TRANSPOSE1XW_WIDTH + (gid_x % MULT_TRANSPOSE1XW_WIDTH);

    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes
                             + (gid_x / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y
                             + slot_index * slot_size_in_bytes;

    // Batched GEMM: step to the gid_z-th destination matrix
    dst_addr_in_bytes += gid_z * dst_stride_z;

    // Move TRANSPOSE_W contiguous source elements to their transposed slot
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    row_chunk = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    VSTORE(TRANSPOSE_W)
    (row_chunk, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
Gian Marco36a0a462018-01-12 10:21:40 +00002789#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002790
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 * will be simply unrolled.
 *
 * Each work-item loads 4 elements from each of 4 consecutive input rows (a0..a3) and writes them to the destination as four
 * groups of 4 elements, placed 4 * MULT_INTERLEAVE4X4_HEIGHT elements apart. Without UNROLL_BLOCK, group k holds column k
 * of the loaded 4x4 block (a0.sk, a1.sk, a2.sk, a3.sk); with UNROLL_BLOCK, group k holds row k (ak) unchanged.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute address for matrix A interleaved - destination.
    // Each work-item writes 16 elements; MULT_INTERLEAVE4X4_HEIGHT consecutive values of y share one destination row,
    // each one offset by 4 elements from the previous one.
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

#if defined(REINTERPRET_INPUT_AS_3D)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
    // Vector-first form of min() (gentype, gentype); clamps to the last plane
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Compute address for source tensor (only needed when the input is not reinterpreted as 3D)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(UNROLL_BLOCK)
    // Store the 4 rows unchanged
    vstore4(a0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(a1, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(a2, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));
    vstore4(a3, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
#else  // defined(UNROLL_BLOCK)
    // Store the 4 columns of the block (transposition of the 4x4 block)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 0 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 4 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 8 * MULT_INTERLEAVE4X4_HEIGHT));

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, ((__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes) + 12 * MULT_INTERLEAVE4X4_HEIGHT));
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002918
Gian Marco36a0a462018-01-12 10:21:40 +00002919#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 *  Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * Each work-item computes a 4x4 block of the output (c00..c30). Every accumulation step multiplies 4 values of the
 * interleaved matrix A by 4 values of the transposed matrix B; the main loop performs two steps per iteration and the
 * left-over loop one step, until COLS_B elements of the reshaped matrix B row have been consumed.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset (in elements) of this work-item's 4-element group within the reshaped rows of A and B
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before offset_row_b is applied)
    __global float *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c00 = 0.0f;
    float4 c10 = 0.0f;
    float4 c20 = 0.0f;
    float4 c30 = 0.0f;

    // Main loop: two accumulation steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c00 += (float4)a0.s0 * b0;
        c10 += (float4)a0.s1 * b0;
        c20 += (float4)a0.s2 * b0;
        c30 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float4)ALPHA;
    c10 = c10 * (float4)ALPHA;
    c20 = c20 * (float4)ALPHA;
    c30 = c30 * (float4)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 4 vector elements to each of the 4 output rows
    __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float4 c0 = vload4(0, src2_addr);

    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Vector-first form of min() (gentype, gentype); clamps to the last plane
    zout = min((uint4)(DEPTH_GEMM3D - 1), zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x4 block
    vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3117
/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication.
3120 *
3121 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003122 *
Gian Marco19835e52018-01-30 13:35:54 +00003123 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
3124 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3127 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3128 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003129 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003130 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
3131 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3132 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3133 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3134 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3135 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003136 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3137 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003138 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3139 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3140 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3141 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3142 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3143 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003144 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003145 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3146 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3147 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3148 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3149 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
3151 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3152 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3153 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003154 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003155 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003156 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003157 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003158 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003159 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003160 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3161 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3162 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003163 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003164 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Reshaped row of B (resp. A) read by this work-item, and the batch index.
    // MULT_TRANSPOSE1XW_WIDTH (resp. MULT_INTERLEAVE4X4_HEIGHT) consecutive work-items share one reshaped row.
    const int b_row     = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row     = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch_idx = get_global_id(2);

    // Element offset of this work-item's 4-wide slot inside the shared reshaped rows
    const int a_slot = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_slot = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // Byte offsets of the start of the reshaped rows of A (src0) and B (src1)
    int a_offset_bytes = batch_idx * src0_stride_z + a_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = b_row * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // B has only MATRIX_B_DEPTH slices while A has more: wrap the batch index so B is not read out of bounds
    b_offset_bytes += (batch_idx % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch_idx * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *a_ptr = (__global float *)(src0_ptr + a_offset_bytes) + a_slot;
    __global float *b_ptr = (__global float *)(src1_ptr + b_offset_bytes) + b_slot;

    // 4x4 output tile: c<r><c> accumulates row r, column c
    float c00 = 0.0f, c01 = 0.0f, c02 = 0.0f, c03 = 0.0f;
    float c10 = 0.0f, c11 = 0.0f, c12 = 0.0f, c13 = 0.0f;
    float c20 = 0.0f, c21 = 0.0f, c22 = 0.0f, c23 = 0.0f;
    float c30 = 0.0f, c31 = 0.0f, c32 = 0.0f, c33 = 0.0f;

// One step along K: load 4 values of A (one per tile row) and 4 values of B (one per tile column),
// move both pointers to this work-item's next K step, then update the whole tile with fused multiply-adds.
#define GEMM_F32_BIFROST_K_STEP()                 \
    do                                            \
    {                                             \
        const float4 a_k = vload4(0, a_ptr);      \
        const float4 b_k = vload4(0, b_ptr);      \
        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;   \
        b_ptr += 4 * MULT_TRANSPOSE1XW_WIDTH;     \
        c00 = fma(a_k.s0, b_k.s0, c00);           \
        c01 = fma(a_k.s0, b_k.s1, c01);           \
        c02 = fma(a_k.s0, b_k.s2, c02);           \
        c03 = fma(a_k.s0, b_k.s3, c03);           \
        c10 = fma(a_k.s1, b_k.s0, c10);           \
        c11 = fma(a_k.s1, b_k.s1, c11);           \
        c12 = fma(a_k.s1, b_k.s2, c12);           \
        c13 = fma(a_k.s1, b_k.s3, c13);           \
        c20 = fma(a_k.s2, b_k.s0, c20);           \
        c21 = fma(a_k.s2, b_k.s1, c21);           \
        c22 = fma(a_k.s2, b_k.s2, c22);           \
        c23 = fma(a_k.s2, b_k.s3, c23);           \
        c30 = fma(a_k.s3, b_k.s0, c30);           \
        c31 = fma(a_k.s3, b_k.s1, c31);           \
        c32 = fma(a_k.s3, b_k.s2, c32);           \
        c33 = fma(a_k.s3, b_k.s3, c33);           \
    } while(0)

// Number of K steps: COLS_B divided by the per-step advance of the B pointer
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    int k = 0;

    // Main loop, manually unrolled by 4
    for(; k <= (int)(COLS_MTX_B - 4); k += 4)
    {
        GEMM_F32_BIFROST_K_STEP();
        GEMM_F32_BIFROST_K_STEP();
        GEMM_F32_BIFROST_K_STEP();
        GEMM_F32_BIFROST_K_STEP();
    }

    // Leftover K steps
    for(; k < (int)(COLS_MTX_B); ++k)
    {
        GEMM_F32_BIFROST_K_STEP();
    }

#undef GEMM_F32_BIFROST_K_STEP

#if defined(ALPHA)
    // Scale the product by alpha (kept as c = c * ALPHA so the macro expands exactly as before)
    c00 = c00 * ALPHA;
    c01 = c01 * ALPHA;
    c02 = c02 * ALPHA;
    c03 = c03 * ALPHA;

    c10 = c10 * ALPHA;
    c11 = c11 * ALPHA;
    c12 = c12 * ALPHA;
    c13 = c13 * ALPHA;

    c20 = c20 * ALPHA;
    c21 = c21 * ALPHA;
    c22 = c22 * ALPHA;
    c23 = c23 * ALPHA;

    c30 = c30 * ALPHA;
    c31 = c31 * ALPHA;
    c32 = c32 * ALPHA;
    c33 = c33 * ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 4-element slice of the src2 vector to every row of the tile
    __global float *bias_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const float4 bias = vload4(0, bias_addr);

    c00 += bias.s0;
    c01 += bias.s1;
    c02 += bias.s2;
    c03 += bias.s3;

    c10 += bias.s0;
    c11 += bias.s1;
    c12 += bias.s2;
    c13 += bias.s3;

    c20 += bias.s0;
    c21 += bias.s1;
    c22 += bias.s2;
    c23 += bias.s3;

    c30 += bias.s0;
    c31 += bias.s1;
    c32 += bias.s2;
    c33 += bias.s3;
#endif /* defined(ADD_VEC_C) */

    // Address of the tile's first destination element for this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D tile is written into a 3D tensor. Each of the 4 rows may fall into a different plane,
    // so each row gets its own cross-plane padding offset:
    //
    // |                  |
    // |      plane0      |
    // |                  |
    // |__________________|
    // |******************|
    // |  cross_plane_pad |
    // |******************|
    // |                  |
    // |      plane1      |
    // |                  |
    // |__________________|

    // Plane index of each row: row index M (get_global_id(1) * 4 + r) divided by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Byte offset contributed by the padding rows between planes
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence the DEPTH_GEMM3D factor on the Z stride
    dst_addr += batch_idx * dst_stride_z * DEPTH_GEMM3D;

    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one Z slice per batch
    dst_addr += batch_idx * dst_stride_z;

    vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
3463
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003464#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** Matrix multiplication between matrix A (src0) and matrix B (src1) for F16 data, accumulating in half precision.
 *
 * Both inputs must already be reshaped: matrix A with @ref gemm_interleave4x4_16bit and matrix B with @ref gemm_transpose1x8.
 * Each work-item computes a 4x8 block of the destination.
 *
 * If -DALPHA is defined, the product is multiplied by ALPHA. If -DADD_VEC_C is defined, a vector (src2) is then
 * added to every row of the result.
 *
 * @note The number of columns of matrix B must be passed with -DCOLS_B. It is used as the length, in elements, of one row of the reshaped matrix B.
 * @note The optional alpha value is passed with -DALPHA.
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed with -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed with -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note If matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed with -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16).
 *       The kernel then wraps the batch index of B to avoid out-of-bounds reads.
 *       This happens when GEMM performs the element-wise multiplication of a batched matrix multiplication (2D Winograd) with multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the output has to be reinterpreted as a 3D tensor (e.g. the output of a convolution layer), the following must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output 3D tensor
 *       -# DEPTH_GEMM3D: The depth of the output 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) must equal the number of rows (M) of the non-reshaped matrix A
 *
 * @note To add a 3rd input (src2), pass -DADD_VEC_C at compile time.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom padding between consecutive output planes, in units of elements along Y (only if REINTERPRET_OUTPUT_AS_3D is defined)
 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                 VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Reshaped row of B (resp. A) read by this work-item, and the batch index
    const int b_row     = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int a_row     = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch_idx = get_global_id(2);

    // Element offset of this work-item's slot inside the shared reshaped rows (4 values of A, 8 values of B)
    const int a_slot = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_slot = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the start of the reshaped rows of A (src0) and B (src1)
    int a_offset_bytes = batch_idx * src0_stride_z + a_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = b_row * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // B has only MATRIX_B_DEPTH slices while A has more: wrap the batch index so B is not read out of bounds
    b_offset_bytes += (batch_idx % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch_idx * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_bytes);

    // End of the reshaped B row, measured from the row start (before this work-item's slot offset)
    __global half *const b_row_end = b_ptr + COLS_B;

    b_ptr += b_slot;
    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_bytes) + a_slot;

    // One 8-wide accumulator per output row
    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

// Accumulate one K step: each of the 4 A values (one per output row) scales the 8 B values
#define GEMM_F16_ACCUMULATE(a_vals, b_vals) \
    do                                      \
    {                                       \
        const half4 a_k = (a_vals);         \
        const half8 b_k = (b_vals);         \
        acc0 += (half8)a_k.s0 * b_k;        \
        acc1 += (half8)a_k.s1 * b_k;        \
        acc2 += (half8)a_k.s2 * b_k;        \
        acc3 += (half8)a_k.s3 * b_k;        \
    } while(0)

    // Main loop: two K steps per iteration
    while(b_ptr <= (b_row_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        GEMM_F16_ACCUMULATE(vload4(0, a_ptr), vload8(0, b_ptr));
        GEMM_F16_ACCUMULATE(vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT), vload8(0, b_ptr + 8 * MULT_TRANSPOSE1XW_WIDTH));

        a_ptr += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Leftover K steps, one at a time
    while(b_ptr < b_row_end)
    {
        GEMM_F16_ACCUMULATE(vload4(0, a_ptr), vload8(0, b_ptr));

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

#undef GEMM_F16_ACCUMULATE

#if defined(ALPHA)
    // Scale the product by alpha
    acc0 = acc0 * (half8)ALPHA;
    acc1 = acc1 * (half8)ALPHA;
    acc2 = acc2 * (half8)ALPHA;
    acc3 = acc3 * (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the same 8-element slice of the src2 vector to every row of the block
    __global half *bias_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const half8 bias = vload8(0, bias_addr);

    acc0 += bias;
    acc1 += bias;
    acc2 += bias;
    acc3 += bias;
#endif /* defined(ADD_VEC_C) */

    // Address of the block's first destination element for this work-item
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D block is written into a 3D tensor. Each of the 4 rows may fall into a different plane,
    // so each row gets its own cross-plane padding offset:
    //
    // |                  |
    // |      plane0      |
    // |                  |
    // |__________________|
    // |******************|
    // |  cross_plane_pad |
    // |******************|
    // |                  |
    // |      plane1      |
    // |                  |
    // |__________________|

    // Plane index of each row: row index M (get_global_id(1) * 4 + r) divided by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Byte offset contributed by the padding rows between planes
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence the DEPTH_GEMM3D factor on the Z stride
    dst_addr += batch_idx * dst_stride_z * DEPTH_GEMM3D;

    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one Z slice per batch
    dst_addr += batch_idx * dst_stride_z;

    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003666
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003667/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
3668 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
3669 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003670 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3671 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003672 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
3673 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
3674 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3675 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3676 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
3677 *
3678 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
3679 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3680 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3681 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3682 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3683 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003684 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3685 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003686 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3687 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3688 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3689 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3690 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3691 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3692 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3693 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3694 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3695 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3696 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3697 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003698 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
3699 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3700 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
3701 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00003702 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
3703 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3704 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3705 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3706 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3707 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3708 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3709 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3710 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3711 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
3712 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                        VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
                                                        uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                        ,
                                                        uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                       )
{
    // Reshaped row of B (block_x) and of A (block_y) read by this work-item, plus the batch index.
    // MULT_TRANSPOSE1XW_WIDTH / MULT_INTERLEAVE4X4_HEIGHT consecutive work-items share one reshaped row.
    int block_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int block_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int batch   = get_global_id(2);

    // Position (in elements) of this work-item's 4-wide A slice / 8-wide B slice inside the shared row
    const int a_slice = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_slice = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Distance (in elements) between two consecutive k-steps in the reshaped rows of A and B
    const int a_step = 4 * MULT_INTERLEAVE4X4_HEIGHT;
    const int b_step = 8 * MULT_TRANSPOSE1XW_WIDTH;

    // Byte offsets of the reshaped rows of A and B
    int a_offset_bytes = batch * src0_stride_z + block_y * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = block_x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B only has MATRIX_B_DEPTH batches (3D B with a >3D A): wrap the batch index instead of sliding past it
    b_offset_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_bytes);
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_bytes);

    // End of the reshaped row of B, taken before the per-work-item slice offset is applied
    __global half *b_end = b_ptr + COLS_B;

    a_ptr += a_slice;
    b_ptr += b_slice;

    // 4x8 output tile, accumulated in single precision
    float8 acc0 = 0.0f;
    float8 acc1 = 0.0f;
    float8 acc2 = 0.0f;
    float8 acc3 = 0.0f;

    // Main loop: two k-steps per iteration
    for(; b_ptr <= (b_end - 2 * b_step); a_ptr += 2 * a_step, b_ptr += 2 * b_step)
    {
        // First k-step
        float4 lhs = convert_float4(vload4(0, a_ptr));
        float8 rhs = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)lhs.s0 * rhs;
        acc1 += (float8)lhs.s1 * rhs;
        acc2 += (float8)lhs.s2 * rhs;
        acc3 += (float8)lhs.s3 * rhs;

        // Second k-step
        lhs = convert_float4(vload4(0, a_ptr + a_step));
        rhs = convert_float8(vload8(0, b_ptr + b_step));

        acc0 += (float8)lhs.s0 * rhs;
        acc1 += (float8)lhs.s1 * rhs;
        acc2 += (float8)lhs.s2 * rhs;
        acc3 += (float8)lhs.s3 * rhs;
    }

    // Tail loop: remaining k-steps, one at a time
    for(; b_ptr < b_end; a_ptr += a_step, b_ptr += b_step)
    {
        float4 lhs = convert_float4(vload4(0, a_ptr));
        float8 rhs = convert_float8(vload8(0, b_ptr));

        acc0 += (float8)lhs.s0 * rhs;
        acc1 += (float8)lhs.s1 * rhs;
        acc2 += (float8)lhs.s2 * rhs;
        acc3 += (float8)lhs.s3 * rhs;
    }

#if defined(ALPHA)
    // Scale the product by alpha (still in float)
    acc0 *= (float8)ALPHA;
    acc1 *= (float8)ALPHA;
    acc2 *= (float8)ALPHA;
    acc3 *= (float8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add vector C (src2), read as half and widened to float, to each of the 4 rows of the tile
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const float8 vec_c = convert_float8(vload8(0, vec_c_ptr));

    acc0 += vec_c;
    acc1 += vec_c;
    acc2 += vec_c;
    acc3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Base address of this work-item's output tile
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 4 output rows may belong to different planes of the 3D output. Compute each row's plane
    // (M index / HEIGHT_GEMM3D, clamped to the last plane) and turn it into the byte offset needed
    // to skip the cross-plane padding that precedes that plane.
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches are in the 4th dimension, so one batch spans DEPTH_GEMM3D Z-slices
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Narrow back to half and store the 4x8 tile
    vstore8(convert_half8(acc0), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(convert_half8(acc1), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(convert_half8(acc2), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(convert_half8(acc3), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one batch per Z-slice
    dst_addr += batch * dst_stride_z;

    // Narrow back to half and store the 4x8 tile
    vstore8(convert_half8(acc0), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(convert_half8(acc1), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(convert_half8(acc2), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(convert_half8(acc3), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3868
/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A (src0) and matrix B (src1).
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication.
 * Products are accumulated in half precision (the accumulators are half8 and updated with fma).
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                          IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                          VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                          IMAGE_DECLARATION(dst),
                                                          uint src0_stride_z,
                                                          uint src1_stride_z,
                                                          uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                          ,
                                                          uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                         )
{
    // Reshaped row of B (block_x) and of A (block_y) read by this work-item, plus the batch index.
    // MULT_TRANSPOSE1XW_WIDTH / MULT_INTERLEAVE4X4_HEIGHT consecutive work-items share one reshaped row.
    int block_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int block_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int batch   = get_global_id(2);

    // Position (in elements) of this work-item's 4-wide A slice / 8-wide B slice inside the shared row
    const int a_slice = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int b_slice = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Distance (in elements) between two consecutive k-steps in the reshaped rows of A and B
    const int a_step = 4 * MULT_INTERLEAVE4X4_HEIGHT;
    const int b_step = 8 * MULT_TRANSPOSE1XW_WIDTH;

    // Byte offsets of the reshaped rows of A and B
    int a_offset_bytes = batch * src0_stride_z + block_y * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_offset_bytes = block_x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Matrix B only has MATRIX_B_DEPTH batches (3D B with a >3D A): wrap the batch index instead of sliding past it
    b_offset_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // The loops below are bounded by a k-step counter, so no end pointer for B is needed here
    __global half *a_ptr = (__global half *)(src0_ptr + a_offset_bytes) + a_slice;
    __global half *b_ptr = (__global half *)(src1_ptr + b_offset_bytes) + b_slice;

    // 4x8 output tile, accumulated in half precision
    half8 acc0 = 0.0f;
    half8 acc1 = 0.0f;
    half8 acc2 = 0.0f;
    half8 acc3 = 0.0f;

// Number of k-steps: the reshaped row of B holds COLS_B elements, 8 * MULT_TRANSPOSE1XW_WIDTH per k-step
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int k = 0;

    // Main loop: four k-steps per iteration (manually unrolled)
    for(; k <= (int)(COLS_MTX_B - 4); k += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 the A values of two consecutive k-steps are contiguous,
        // so a single 8-wide load of A feeds two k-steps (lanes s0-s3, then lanes s4-s7).

        // k-steps 0 and 1
        half8 lhs = vload8(0, a_ptr);
        half8 rhs = vload8(0, b_ptr);

        a_ptr += 2 * a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);

        rhs = vload8(0, b_ptr);
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s4, rhs, acc0);
        acc1 = fma((half8)lhs.s5, rhs, acc1);
        acc2 = fma((half8)lhs.s6, rhs, acc2);
        acc3 = fma((half8)lhs.s7, rhs, acc3);

        // k-steps 2 and 3
        lhs = vload8(0, a_ptr);
        rhs = vload8(0, b_ptr);

        a_ptr += 2 * a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);

        rhs = vload8(0, b_ptr);
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s4, rhs, acc0);
        acc1 = fma((half8)lhs.s5, rhs, acc1);
        acc2 = fma((half8)lhs.s6, rhs, acc2);
        acc3 = fma((half8)lhs.s7, rhs, acc3);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // k-step 0
        half4 lhs = vload4(0, a_ptr);
        half8 rhs = vload8(0, b_ptr);

        a_ptr += a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);

        // k-step 1
        lhs = vload4(0, a_ptr);
        rhs = vload8(0, b_ptr);

        a_ptr += a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);

        // k-step 2
        lhs = vload4(0, a_ptr);
        rhs = vload8(0, b_ptr);

        a_ptr += a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);

        // k-step 3
        lhs = vload4(0, a_ptr);
        rhs = vload8(0, b_ptr);

        a_ptr += a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Tail loop: remaining k-steps, one at a time
    for(; k < (int)(COLS_MTX_B); ++k)
    {
        half4 lhs = vload4(0, a_ptr);
        half8 rhs = vload8(0, b_ptr);

        a_ptr += a_step;
        b_ptr += b_step;

        acc0 = fma((half8)lhs.s0, rhs, acc0);
        acc1 = fma((half8)lhs.s1, rhs, acc1);
        acc2 = fma((half8)lhs.s2, rhs, acc2);
        acc3 = fma((half8)lhs.s3, rhs, acc3);
    }

#if defined(ALPHA)
    // Scale the product by alpha
    acc0 *= (half8)ALPHA;
    acc1 *= (half8)ALPHA;
    acc2 *= (half8)ALPHA;
    acc3 *= (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add vector C (src2) to each of the 4 rows of the tile
    __global half *vec_c_ptr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    const half8 vec_c = vload8(0, vec_c_ptr);

    acc0 += vec_c;
    acc1 += vec_c;
    acc2 += vec_c;
    acc3 += vec_c;
#endif /* defined(ADD_VEC_C) */

    // Base address of this work-item's output tile
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 4 output rows may belong to different planes of the 3D output. Compute each row's plane
    // (M index / HEIGHT_GEMM3D, clamped to the last plane) and turn it into the byte offset needed
    // to skip the cross-plane padding that precedes that plane.
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);
    zout *= (cross_plane_pad * dst_stride_y);

    // Batches are in the 4th dimension, so one batch spans DEPTH_GEMM3D Z-slices
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;

    // Store the 4x8 tile
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Batched GEMM: one batch per Z-slice
    dst_addr += batch * dst_stride_z;

    // Store the 4x8 tile
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
4152
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// Every compile-time option is tested with defined(): a bare "(NUM_ELEMS_PROCESSED_PER_THREAD_Y)"
// only works by accident (undefined identifiers evaluate to 0) and breaks if the macro is defined empty.
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
// Vector of NUM_ELEMS_PROCESSED_PER_THREAD_X elements of DATA_TYPE
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004160/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
4161 *
4162 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004163 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004164 * @note This OpenCL kernel works with floating point data types (F16/F32)
4165 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
4166 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004167 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004168 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
4169 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004170 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004171 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4172 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004173 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4174 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4175 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4176 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4177 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004178 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4179 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004180 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004181 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4182 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4183 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4184 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4185 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004186 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004187 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4188 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4189 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4190 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4191 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004196 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004197 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
4201 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004202 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4203 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4204 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004205 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4206 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004207 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004208__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
4209 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004210#if defined(ADD_VEC_C)
4211 VECTOR_DECLARATION(src2),
4212#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004213 IMAGE_DECLARATION(dst),
4214 uint src0_stride_z,
4215 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004216 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004217#if defined(REINTERPRET_INPUT_AS_3D)
4218 ,
4219 uint src_cross_plane_pad
4220#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004221#if defined(REINTERPRET_OUTPUT_AS_3D)
4222 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004223 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004224#endif // REINTERPRET_OUTPUT_AS_3D
4225 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004226{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004227 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004228
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004229 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004230 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004231
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004232 // Update address for the matrix A
4233 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004234
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004235 // Update address for the matrix B
4236 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004237
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004238#if defined(REINTERPRET_INPUT_AS_3D)
4239 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4240 // in order to take into account the presence of possible cross plane paddings
4241 //
4242 // | |
4243 // | plane0 |
4244 // | |
4245 // |__________________|
4246 // |******************|
4247 // | cross_plane_pad |
4248 // |******************|
4249 // | |
4250 // | plane1 |
4251 // | |
4252 // |__________________|
4253
4254 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4255 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4256 zin = min(DEPTH_GEMM3D - 1, zin);
4257
4258 // Add offset due to the cross plane paddings
4259 zin *= (src_cross_plane_pad * src0_stride_y);
4260
4261 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4262 // multiply src0_stride_z by DEPTH_GEMM3D
4263 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4264
4265#else // defined(REINTERPRET_INPUT_AS_3D)
4266
Gian Marcoae2af742018-02-15 12:35:44 +00004267 // Add offset for batched GEMM
4268 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004269
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004270#endif // defined(REINTERPRET_INPUT_AS_3D)
4271
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004272#if defined(MATRIX_B_DEPTH)
4273 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4274 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4275#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004276 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004277#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004278
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004279 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
4280
4281 VECTOR_TYPE acc0 = 0.0f;
4282#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4283 VECTOR_TYPE acc1 = 0.0f;
4284#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4286 VECTOR_TYPE acc2 = 0.0f;
4287#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4288#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4289 VECTOR_TYPE acc3 = 0.0f;
4290#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4291
Georgios Pinitas96880cf2017-10-20 18:52:20 +01004292 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004293 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004294#if defined(REINTERPRET_INPUT_AS_3D)
4295 // Load values from matrix A
4296 VEC_DATA_TYPE(DATA_TYPE, 2)
4297 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4298#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4299 VEC_DATA_TYPE(DATA_TYPE, 2)
4300 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4301#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4302#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4303 VEC_DATA_TYPE(DATA_TYPE, 2)
4304 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4305#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4307 VEC_DATA_TYPE(DATA_TYPE, 2)
4308 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4309#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4310#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004311 // Load values from matrix A
4312 VEC_DATA_TYPE(DATA_TYPE, 2)
4313 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4315 VEC_DATA_TYPE(DATA_TYPE, 2)
4316 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4317#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4318#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4319 VEC_DATA_TYPE(DATA_TYPE, 2)
4320 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4321#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4322#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4323 VEC_DATA_TYPE(DATA_TYPE, 2)
4324 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4325#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004326#endif // defined(REINTERPRET_INPUT_AS_3D)
4327
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004328 // Load values from matrix B
4329 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
4330 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004331
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004332 // Accumulate
4333 acc0 += b0 * (VECTOR_TYPE)a0.s0;
4334 acc0 += b1 * (VECTOR_TYPE)a0.s1;
4335#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4336 acc1 += b0 * (VECTOR_TYPE)a1.s0;
4337 acc1 += b1 * (VECTOR_TYPE)a1.s1;
4338#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4339#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4340 acc2 += b0 * (VECTOR_TYPE)a2.s0;
4341 acc2 += b1 * (VECTOR_TYPE)a2.s1;
4342#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4343#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4344 acc3 += b0 * (VECTOR_TYPE)a3.s0;
4345 acc3 += b1 * (VECTOR_TYPE)a3.s1;
4346#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004347 }
4348
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004349 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004350 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004351#if defined(REINTERPRET_INPUT_AS_3D)
4352 // Load values from matrix A
4353 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4354#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4355 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4356#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4357#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4358 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4360#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4361 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4362#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4363#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004364 // Load values from matrix A
4365 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4366#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4367 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4368#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4370 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4371#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4372#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4373 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4374#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004375#endif // defined(REINTERPRET_INPUT_AS_3D)
4376
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004377 // Load values from matrix B
4378 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004379
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004380 // Accumulate
4381 acc0 += b0 * (VECTOR_TYPE)a0;
4382#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4383 acc1 += b0 * (VECTOR_TYPE)a1;
4384#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4385#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4386 acc2 += b0 * (VECTOR_TYPE)a2;
4387#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4388#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4389 acc3 += b0 * (VECTOR_TYPE)a3;
4390#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004391 }
4392
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004393 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004394 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4395
Gian Marcoae2af742018-02-15 12:35:44 +00004396 // Compute dst address
4397 __global uchar *dst_addr = offset(&dst, 0, 0);
4398
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004399 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004400#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004401 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004402#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4404 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
4405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4407 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
4408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4409#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4410 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
4411#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4412
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004413#if defined(ADD_VEC_C)
4414 // *INDENT-OFF*
4415 // clang-format off
4416 __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4417 VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
4418 // clang-format on
4419 // *INDENT-ON*
4420
4421 acc0 += c0;
4422#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4423 acc1 += c0;
4424#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4425#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4426 acc2 += c0;
4427#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4428#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4429 acc3 += c0;
4430#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4431#endif /* defined(ADD_VEC_C) */
4432
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004433 int z = get_global_id(2);
4434
4435#if defined(REINTERPRET_OUTPUT_AS_3D)
4436 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004437 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004438 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004439 // | |
4440 // | plane0 |
4441 // | |
4442 // |__________________|
4443 // |******************|
4444 // | cross_plane_pad |
4445 // |******************|
4446 // | |
4447 // | plane1 |
4448 // | |
4449 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004450
4451 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4452 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4453 zout = min(DEPTH_GEMM3D - 1, zout);
4454
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004455 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004456 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004457
4458 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4459 // multiply dst_stride_z by DEPTH_GEMM3D
4460 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4461
4462 // Store output block
4463 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4464 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
4465#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4466 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4467 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
4468#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4469#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4470 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4471 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
4472#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4473#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4474 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
4475 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
4476#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4477
4478#else // defined(REINTERPRET_OUTPUT_AS_3D)
4479 // Add offset for batched GEMM
4480 dst_addr += z * dst_stride_z;
4481
4482 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004483 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004484 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004485#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004486 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004487 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004488#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4489#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004490 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004491 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004492#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4493#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004494 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00004495 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004496#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004497#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004498}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004499#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004500
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01004501/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) when neither matrix has been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004502 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004503 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4504 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004505 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4506 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4507 * This kernel is written for -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4: it always loads and stores 4 columns per work-item.
4508 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4509 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004510 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4511 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004512 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004513 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4514 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004515 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4516 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4517 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4518 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A, which is NOT reshaped
4519 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004520 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4521 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004522 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
4523 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4524 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
4525 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4526 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
4527 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4528 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4529 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4530 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
4531 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4532 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
4533 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004534 * @param[in] src2_ptr (Optional) Pointer to the source vector (only if defined ADD_VEC_C). Supported data types: same as @p src0_ptr
4535 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
4536 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
4537 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004538 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4539 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4540 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
4541 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4542 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
4543 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004544 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4545 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4546 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004547 * @param[in] src_cross_plane_pad (Optional) Bottom padding of each plane of the input tensor, in number of rows (only if defined REINTERPRET_INPUT_AS_3D)
4548 * @param[in] dst_cross_plane_pad (Optional) Bottom padding of each plane of the output tensor, in number of rows (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004549 */
4550__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
4551 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004552#if defined(ADD_VEC_C)
4553 VECTOR_DECLARATION(src2),
4554#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004555 IMAGE_DECLARATION(dst),
4556 uint src0_stride_z,
4557 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004558 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004559#if defined(REINTERPRET_INPUT_AS_3D)
4560 ,
4561 uint src_cross_plane_pad
4562#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004563#if defined(REINTERPRET_OUTPUT_AS_3D)
4564 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004565 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004566#endif // REINTERPRET_OUTPUT_AS_3D
4567 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004568{
4569 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4570
4571 // Compute starting address for matrix A and matrix B
4572 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4573
4574 // Update address for matrix A
4575 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4576
4577 // Update address for matrix B
4578 src_addr.s1 += idx * sizeof(float);
4579
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004580#if defined(REINTERPRET_INPUT_AS_3D)
4581 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4582 // in order to take into account the presence of possible cross plane paddings
4583 //
4584 // | |
4585 // | plane0 |
4586 // | |
4587 // |__________________|
4588 // |******************|
4589 // | cross_plane_pad |
4590 // |******************|
4591 // | |
4592 // | plane1 |
4593 // | |
4594 // |__________________|
4595
4596 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4597 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4598 zin = min(DEPTH_GEMM3D - 1, zin);
4599
4600 // Add offset due to the cross plane paddings
4601 zin *= (src_cross_plane_pad * src0_stride_y);
4602
4603 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4604 // multiply src0_stride_z by DEPTH_GEMM3D
4605 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4606
4607#else // defined(REINTERPRET_INPUT_AS_3D)
4608
Gian Marcoae2af742018-02-15 12:35:44 +00004609 // Add offset for batched GEMM
4610 src_addr.s0 += get_global_id(2) * src0_stride_z;
4611
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004612#endif // defined(REINTERPRET_INPUT_AS_3D)
4613
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004614#if defined(MATRIX_B_DEPTH)
4615 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4616 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4617#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004618 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004619#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004620
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004621 // Initialize accumulators
4622 float acc00 = 0.0f;
4623 float acc01 = 0.0f;
4624 float acc02 = 0.0f;
4625 float acc03 = 0.0f;
4626
4627#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4628 float acc10 = 0.0f;
4629 float acc11 = 0.0f;
4630 float acc12 = 0.0f;
4631 float acc13 = 0.0f;
4632#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4633
4634#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4635 float acc20 = 0.0f;
4636 float acc21 = 0.0f;
4637 float acc22 = 0.0f;
4638 float acc23 = 0.0f;
4639#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4640
4641#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4642 float acc30 = 0.0f;
4643 float acc31 = 0.0f;
4644 float acc32 = 0.0f;
4645 float acc33 = 0.0f;
4646#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4647
4648 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004649 int i = 0;
4650 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004651 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004652#if defined(REINTERPRET_INPUT_AS_3D)
4653 // Load values from matrix A and matrix B
4654 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4655#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4656 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4657#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4658#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4659 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4660#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4661#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4662 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4663#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4664#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004665 // Load values from matrix A and matrix B
4666 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004667#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004668 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004669#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4670#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004671 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004672#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4673#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004674 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004675#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004676#endif // defined(REINTERPRET_INPUT_AS_3D)
4677
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004678 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4679 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004680
4681 // Multiply and accumulate
4682 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004683 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004684 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004685 acc03 = fma(a0.s0, b0.s3, acc03);
4686
4687#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004688
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004689 acc10 = fma(a1.s0, b0.s0, acc10);
4690 acc11 = fma(a1.s0, b0.s1, acc11);
4691 acc12 = fma(a1.s0, b0.s2, acc12);
4692 acc13 = fma(a1.s0, b0.s3, acc13);
4693
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004694#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4695#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004696
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004697 acc20 = fma(a2.s0, b0.s0, acc20);
4698 acc21 = fma(a2.s0, b0.s1, acc21);
4699 acc22 = fma(a2.s0, b0.s2, acc22);
4700 acc23 = fma(a2.s0, b0.s3, acc23);
4701
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004702#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4703#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004704
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004705 acc30 = fma(a3.s0, b0.s0, acc30);
4706 acc31 = fma(a3.s0, b0.s1, acc31);
4707 acc32 = fma(a3.s0, b0.s2, acc32);
4708 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004709#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004710
4711 // Load values from matrix A and matrix B
4712 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4713 src_addr.s1 += src1_stride_y;
4714
4715 // Multiply and accumulate
4716 acc00 = fma(a0.s1, b0.s0, acc00);
4717 acc01 = fma(a0.s1, b0.s1, acc01);
4718 acc02 = fma(a0.s1, b0.s2, acc02);
4719 acc03 = fma(a0.s1, b0.s3, acc03);
4720
4721#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4722
4723 acc10 = fma(a1.s1, b0.s0, acc10);
4724 acc11 = fma(a1.s1, b0.s1, acc11);
4725 acc12 = fma(a1.s1, b0.s2, acc12);
4726 acc13 = fma(a1.s1, b0.s3, acc13);
4727
4728#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4729#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4730
4731 acc20 = fma(a2.s1, b0.s0, acc20);
4732 acc21 = fma(a2.s1, b0.s1, acc21);
4733 acc22 = fma(a2.s1, b0.s2, acc22);
4734 acc23 = fma(a2.s1, b0.s3, acc23);
4735
4736#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4737#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4738
4739 acc30 = fma(a3.s1, b0.s0, acc30);
4740 acc31 = fma(a3.s1, b0.s1, acc31);
4741 acc32 = fma(a3.s1, b0.s2, acc32);
4742 acc33 = fma(a3.s1, b0.s3, acc33);
4743#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4744
4745 // Load values from matrix A and matrix B
4746 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4747 src_addr.s1 += src1_stride_y;
4748
4749 // Multiply and accumulate
4750 acc00 = fma(a0.s2, b0.s0, acc00);
4751 acc01 = fma(a0.s2, b0.s1, acc01);
4752 acc02 = fma(a0.s2, b0.s2, acc02);
4753 acc03 = fma(a0.s2, b0.s3, acc03);
4754
4755#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4756
4757 acc10 = fma(a1.s2, b0.s0, acc10);
4758 acc11 = fma(a1.s2, b0.s1, acc11);
4759 acc12 = fma(a1.s2, b0.s2, acc12);
4760 acc13 = fma(a1.s2, b0.s3, acc13);
4761
4762#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4763#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4764
4765 acc20 = fma(a2.s2, b0.s0, acc20);
4766 acc21 = fma(a2.s2, b0.s1, acc21);
4767 acc22 = fma(a2.s2, b0.s2, acc22);
4768 acc23 = fma(a2.s2, b0.s3, acc23);
4769
4770#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4771#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4772
4773 acc30 = fma(a3.s2, b0.s0, acc30);
4774 acc31 = fma(a3.s2, b0.s1, acc31);
4775 acc32 = fma(a3.s2, b0.s2, acc32);
4776 acc33 = fma(a3.s2, b0.s3, acc33);
4777#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4778
4779 // Load values from matrix A and matrix B
4780 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4781 src_addr.s1 += src1_stride_y;
4782
4783 // Multiply and accumulate
4784 acc00 = fma(a0.s3, b0.s0, acc00);
4785 acc01 = fma(a0.s3, b0.s1, acc01);
4786 acc02 = fma(a0.s3, b0.s2, acc02);
4787 acc03 = fma(a0.s3, b0.s3, acc03);
4788
4789#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4790
4791 acc10 = fma(a1.s3, b0.s0, acc10);
4792 acc11 = fma(a1.s3, b0.s1, acc11);
4793 acc12 = fma(a1.s3, b0.s2, acc12);
4794 acc13 = fma(a1.s3, b0.s3, acc13);
4795
4796#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4797#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4798
4799 acc20 = fma(a2.s3, b0.s0, acc20);
4800 acc21 = fma(a2.s3, b0.s1, acc21);
4801 acc22 = fma(a2.s3, b0.s2, acc22);
4802 acc23 = fma(a2.s3, b0.s3, acc23);
4803
4804#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4805#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4806
4807 acc30 = fma(a3.s3, b0.s0, acc30);
4808 acc31 = fma(a3.s3, b0.s1, acc31);
4809 acc32 = fma(a3.s3, b0.s2, acc32);
4810 acc33 = fma(a3.s3, b0.s3, acc33);
4811#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4812
4813 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004814 }
4815
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004816 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004817 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004818#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004819 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004820 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4821#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4822 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4823#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4824#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4825 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4826#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4827#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4828 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4829#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4830#else // defined(REINTERPRET_INPUT_AS_3D)
4831 // Load values from matrix A
4832 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004833#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4834 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4835#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4836#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4837 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4838#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4839#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4840 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4841#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004842#endif // defined(REINTERPRET_INPUT_AS_3D)
4843
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004844 // Load values from matrix B
4845 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004846 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004847
4848 // Multiply and accumulate
4849 acc00 = fma(a0, b0.s0, acc00);
4850 acc01 = fma(a0, b0.s1, acc01);
4851 acc02 = fma(a0, b0.s2, acc02);
4852 acc03 = fma(a0, b0.s3, acc03);
4853#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4854 acc10 = fma(a1, b0.s0, acc10);
4855 acc11 = fma(a1, b0.s1, acc11);
4856 acc12 = fma(a1, b0.s2, acc12);
4857 acc13 = fma(a1, b0.s3, acc13);
4858#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4859#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4860 acc20 = fma(a2, b0.s0, acc20);
4861 acc21 = fma(a2, b0.s1, acc21);
4862 acc22 = fma(a2, b0.s2, acc22);
4863 acc23 = fma(a2, b0.s3, acc23);
4864#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4865#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4866 acc30 = fma(a3, b0.s0, acc30);
4867 acc31 = fma(a3, b0.s1, acc31);
4868 acc32 = fma(a3, b0.s2, acc32);
4869 acc33 = fma(a3, b0.s3, acc33);
4870#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004871
4872 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004873 }
4874
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004875 int z = get_global_id(2);
4876
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004877 // Compute destination address
4878 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4879
4880 // Multiply by the weight of matrix-matrix product and store the result
4881#if defined(ALPHA)
4882 acc00 = acc00 * ALPHA;
4883 acc01 = acc01 * ALPHA;
4884 acc02 = acc02 * ALPHA;
4885 acc03 = acc03 * ALPHA;
4886#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004887#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004888 acc10 = acc10 * ALPHA;
4889 acc11 = acc11 * ALPHA;
4890 acc12 = acc12 * ALPHA;
4891 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004892#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4893#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004894 acc20 = acc20 * ALPHA;
4895 acc21 = acc21 * ALPHA;
4896 acc22 = acc22 * ALPHA;
4897 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004898#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4899#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004900 acc30 = acc30 * ALPHA;
4901 acc31 = acc31 * ALPHA;
4902 acc32 = acc32 * ALPHA;
4903 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004904#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4905
4906 // Compute dst address
4907 __global uchar *dst_addr = offset(&dst, 0, 0);
4908
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004909#if defined(ADD_VEC_C)
4910 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4911 float4 c0 = vload4(0, src2_addr);
4912
4913 acc00 += c0.s0;
4914 acc01 += c0.s1;
4915 acc02 += c0.s2;
4916 acc03 += c0.s3;
4917#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4918 acc10 += c0.s0;
4919 acc11 += c0.s1;
4920 acc12 += c0.s2;
4921 acc13 += c0.s3;
4922#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4923#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4924 acc20 += c0.s0;
4925 acc21 += c0.s1;
4926 acc22 += c0.s2;
4927 acc23 += c0.s3;
4928#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4929#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4930 acc30 += c0.s0;
4931 acc31 += c0.s1;
4932 acc32 += c0.s2;
4933 acc33 += c0.s3;
4934#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4935#endif /* defined(ADD_VEC_C) */
4936
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004937#if defined(REINTERPRET_OUTPUT_AS_3D)
4938 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004939 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004940 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004941 // | |
4942 // | plane0 |
4943 // | |
4944 // |__________________|
4945 // |******************|
4946 // | cross_plane_pad |
4947 // |******************|
4948 // | |
4949 // | plane1 |
4950 // | |
4951 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004952
4953 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4954 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4955 zout = min(DEPTH_GEMM3D - 1, zout);
4956
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004957 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004958 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004959
4960 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4961 // multiply dst_stride_z by DEPTH_GEMM3D
4962 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4963
4964 // Store the output block
4965 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4966#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4967 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4968#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4969#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4970 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4971#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4972#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4973 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004974#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004975
4976#else // defined(REINTERPRET_OUTPUT_AS_3D)
4977 // Add offset for batched GEMM
4978 dst_addr += z * dst_stride_z;
4979
4980 // Store the output block
4981 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4982#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4983 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4984#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4985#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4986 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4987#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4988#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4989 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4990#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4991#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004992}
4993
4994/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4995 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004996 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4997 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004998 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4999 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
5000 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5001 * This kernel requires -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2 (each work-item loads and stores exactly 2 output columns) and supports -DNUM_ELEMS_PROCESSED_PER_THREAD_Y from 1 to 4.
5002 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5003 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005004 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5005 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005006 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005007 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5008 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005009 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5010 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5011 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5012 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A (M) NOT reshaped
5013 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005014 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5015 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005016 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
5017 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5018 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5019 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5020 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5021 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5022 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5023 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5024 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5025 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5026 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5027 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005028 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
5029 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5030 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
5031 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005032 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5033 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5034 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5035 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5036 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5037 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005038 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5039 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5040 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005041 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5042 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005043 */
5044__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
5045 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005046#if defined(ADD_VEC_C)
5047 VECTOR_DECLARATION(src2),
5048#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00005049 IMAGE_DECLARATION(dst),
5050 uint src0_stride_z,
5051 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005052 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005053#if defined(REINTERPRET_INPUT_AS_3D)
5054 ,
5055 uint src_cross_plane_pad
5056#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005057#if defined(REINTERPRET_OUTPUT_AS_3D)
5058 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005059 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005060#endif // REINTERPRET_OUTPUT_AS_3D
5061 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005062{
5063 // Each work-item computes a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows (1 to 4 are handled below) x 2 columns of C.
    // B is read as float2 rows and C is written with vstore2, so this kernel assumes NUM_ELEMS_PROCESSED_PER_THREAD_X == 2
    // (idx below is scaled by it while only 2 columns are loaded/stored). A is read as float8 (8 values along K per iteration).
5064 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5065
5066 // Compute starting address for matrix A and Matrix B
5067 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5068
5069 // Update address for the matrix A
5070 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5071
5072 // Update address for the matrix B
5073 src_addr.s1 += idx * sizeof(float);
5074
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005075#if defined(REINTERPRET_INPUT_AS_3D)
5076 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5077 // in order to take into account the presence of possible cross plane paddings
5078 //
5079 // | |
5080 // | plane0 |
5081 // | |
5082 // |__________________|
5083 // |******************|
5084 // | cross_plane_pad |
5085 // |******************|
5086 // | |
5087 // | plane1 |
5088 // | |
5089 // |__________________|
5090
5091 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5092 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5093 zin = min(DEPTH_GEMM3D - 1, zin);
5094
5095 // Add offset due to the cross plane paddings
5096 zin *= (src_cross_plane_pad * src0_stride_y);
5097
5098 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5099 // multiply src0_stride_z by DEPTH_GEMM3D
5100 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5101
5102#else // defined(REINTERPRET_INPUT_AS_3D)
5103
Gian Marcoae2af742018-02-15 12:35:44 +00005104 // Add offset for batched GEMM
5105 src_addr.s0 += get_global_id(2) * src0_stride_z;
5106
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005107#endif // defined(REINTERPRET_INPUT_AS_3D)
5108
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005109#if defined(MATRIX_B_DEPTH)
5110 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5111 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5112#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005113 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005114#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005115
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005116 // Initialize accumulators
5117 float acc00 = 0.0f;
5118 float acc01 = 0.0f;
5119
5120#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5121 float acc10 = 0.0f;
5122 float acc11 = 0.0f;
5123#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5124#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5125 float acc20 = 0.0f;
5126 float acc21 = 0.0f;
5127#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5128#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5129 float acc30 = 0.0f;
5130 float acc31 = 0.0f;
5131#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5132
5133 // Main loop over K (COLS_A), unrolled by 8. A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005134 int i = 0;
5135 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005136 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005137#if defined(REINTERPRET_INPUT_AS_3D)
5138 // Load values from matrix A
5139 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
5140#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005141 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005142 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005143#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005144
5145 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005146 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5147 src_addr.s1 += src1_stride_y;
5148 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5149 src_addr.s1 += src1_stride_y;
5150 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5151 src_addr.s1 += src1_stride_y;
5152 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5153 src_addr.s1 += src1_stride_y;
5154 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5155 src_addr.s1 += src1_stride_y;
5156 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5157 src_addr.s1 += src1_stride_y;
5158 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5159 src_addr.s1 += src1_stride_y;
5160 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5161 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005162
5163 // Multiply and accumulate
5164 acc00 = fma(a0.s0, b0.s0, acc00);
5165 acc00 = fma(a0.s1, b1.s0, acc00);
5166 acc00 = fma(a0.s2, b2.s0, acc00);
5167 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005168 acc00 = fma(a0.s4, b4.s0, acc00);
5169 acc00 = fma(a0.s5, b5.s0, acc00);
5170 acc00 = fma(a0.s6, b6.s0, acc00);
5171 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005172
5173 acc01 = fma(a0.s0, b0.s1, acc01);
5174 acc01 = fma(a0.s1, b1.s1, acc01);
5175 acc01 = fma(a0.s2, b2.s1, acc01);
5176 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005177 acc01 = fma(a0.s4, b4.s1, acc01);
5178 acc01 = fma(a0.s5, b5.s1, acc01);
5179 acc01 = fma(a0.s6, b6.s1, acc01);
5180 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005181
5182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005183#if defined(REINTERPRET_INPUT_AS_3D)
5184 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5185#else // defined(REINTERPRET_INPUT_AS_3D)
5186 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5187#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005188 acc10 = fma(a0.s0, b0.s0, acc10);
5189 acc10 = fma(a0.s1, b1.s0, acc10);
5190 acc10 = fma(a0.s2, b2.s0, acc10);
5191 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005192 acc10 = fma(a0.s4, b4.s0, acc10);
5193 acc10 = fma(a0.s5, b5.s0, acc10);
5194 acc10 = fma(a0.s6, b6.s0, acc10);
5195 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005196
5197 acc11 = fma(a0.s0, b0.s1, acc11);
5198 acc11 = fma(a0.s1, b1.s1, acc11);
5199 acc11 = fma(a0.s2, b2.s1, acc11);
5200 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005201 acc11 = fma(a0.s4, b4.s1, acc11);
5202 acc11 = fma(a0.s5, b5.s1, acc11);
5203 acc11 = fma(a0.s6, b6.s1, acc11);
5204 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005205#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5206#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005207#if defined(REINTERPRET_INPUT_AS_3D)
5208 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5209#else // defined(REINTERPRET_INPUT_AS_3D)
5210 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5211#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005212 acc20 = fma(a0.s0, b0.s0, acc20);
5213 acc20 = fma(a0.s1, b1.s0, acc20);
5214 acc20 = fma(a0.s2, b2.s0, acc20);
5215 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005216 acc20 = fma(a0.s4, b4.s0, acc20);
5217 acc20 = fma(a0.s5, b5.s0, acc20);
5218 acc20 = fma(a0.s6, b6.s0, acc20);
5219 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005220
5221 acc21 = fma(a0.s0, b0.s1, acc21);
5222 acc21 = fma(a0.s1, b1.s1, acc21);
5223 acc21 = fma(a0.s2, b2.s1, acc21);
5224 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005225 acc21 = fma(a0.s4, b4.s1, acc21);
5226 acc21 = fma(a0.s5, b5.s1, acc21);
5227 acc21 = fma(a0.s6, b6.s1, acc21);
5228 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005229#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5230#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005231#if defined(REINTERPRET_INPUT_AS_3D)
5232 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5233#else // defined(REINTERPRET_INPUT_AS_3D)
5234 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5235#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005236 acc30 = fma(a0.s0, b0.s0, acc30);
5237 acc30 = fma(a0.s1, b1.s0, acc30);
5238 acc30 = fma(a0.s2, b2.s0, acc30);
5239 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005240 acc30 = fma(a0.s4, b4.s0, acc30);
5241 acc30 = fma(a0.s5, b5.s0, acc30);
5242 acc30 = fma(a0.s6, b6.s0, acc30);
5243 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005244
5245 acc31 = fma(a0.s0, b0.s1, acc31);
5246 acc31 = fma(a0.s1, b1.s1, acc31);
5247 acc31 = fma(a0.s2, b2.s1, acc31);
5248 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005249 acc31 = fma(a0.s4, b4.s1, acc31);
5250 acc31 = fma(a0.s5, b5.s1, acc31);
5251 acc31 = fma(a0.s6, b6.s1, acc31);
5252 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005253#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005254
5255 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005256 }
5257 // Left-over loop: process the remaining COLS_A % 8 values along K one at a time
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005258 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005259 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005260#if defined(REINTERPRET_INPUT_AS_3D)
5261 // Load values from matrix A
5262 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5263#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5264 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5265#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5266#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5267 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5268#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5269#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5270 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5271#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5272#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005273 // Load values from matrix A
5274 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5275#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5276 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5279 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5280#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5281#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5282 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5283#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005284#endif // defined(REINTERPRET_INPUT_AS_3D)
5285
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005286 // Load values from matrix B
5287 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005288 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005289
5290 // Multiply and accumulate
5291 acc00 = fma(a0, b0.s0, acc00);
5292 acc01 = fma(a0, b0.s1, acc01);
5293#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5294 acc10 = fma(a1, b0.s0, acc10);
5295 acc11 = fma(a1, b0.s1, acc11);
5296#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5297#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5298 acc20 = fma(a2, b0.s0, acc20);
5299 acc21 = fma(a2, b0.s1, acc21);
5300#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5301#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5302 acc30 = fma(a3, b0.s0, acc30);
5303 acc31 = fma(a3, b0.s1, acc31);
5304#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005305
5306 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005307 }
5308
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005309 // Multiply by the weight of matrix-matrix product (alpha). The result is stored at the end of the kernel
5310#if defined(ALPHA)
5311 acc00 = acc00 * ALPHA;
5312 acc01 = acc01 * ALPHA;
5313#endif // defined(ALPHA)
5314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5315 acc10 = acc10 * ALPHA;
5316 acc11 = acc11 * ALPHA;
5317#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5318#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5319 acc20 = acc20 * ALPHA;
5320 acc21 = acc21 * ALPHA;
5321#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5322#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5323 acc30 = acc30 * ALPHA;
5324 acc31 = acc31 * ALPHA;
5325#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5326
5327 int z = get_global_id(2);
5328
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005329 // Compute destination address
5330 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5331
Gian Marcoae2af742018-02-15 12:35:44 +00005332 // Compute dst address
5333 __global uchar *dst_addr = offset(&dst, 0, 0);
5334
    // The src2 vector is added after the ALPHA scaling above, so it is not multiplied by alpha.
    // The same 2 values (one per output column) are added to every row of the tile.
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005335#if defined(ADD_VEC_C)
5336 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
5337 float2 c0 = vload2(0, src2_addr);
5338
5339 acc00 += c0.s0;
5340 acc01 += c0.s1;
5341#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5342 acc10 += c0.s0;
5343 acc11 += c0.s1;
5344#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5346 acc20 += c0.s0;
5347 acc21 += c0.s1;
5348#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5349#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5350 acc30 += c0.s0;
5351 acc31 += c0.s1;
5352#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5353#endif /* defined(ADD_VEC_C) */
5354
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005355#if defined(REINTERPRET_OUTPUT_AS_3D)
5356 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005357 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005358 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005359 // | |
5360 // | plane0 |
5361 // | |
5362 // |__________________|
5363 // |******************|
5364 // | cross_plane_pad |
5365 // |******************|
5366 // | |
5367 // | plane1 |
5368 // | |
5369 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00005370
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005371 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5372 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5373 zout = min(DEPTH_GEMM3D - 1, zout);
5374
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005375 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005376 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005377
5378 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5379 // multiply dst_stride_z by DEPTH_GEMM3D
5380 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
5381
5382 // Store the output block
5383 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005384#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005385 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005386#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5387#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005388 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005389#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5390#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005391 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005392#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005393
5394#else // defined(REINTERPRET_OUTPUT_AS_3D)
5395 // Add offset for batched GEMM
5396 dst_addr += z * dst_stride_z;
5397
5398 // Store the output block
5399 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
5400#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5401 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
5402#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5403#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5404 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
5405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5406#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5407 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
5408#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5409#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005410}
5411
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005412#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005413/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5414 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005415 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5416 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005417 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
5418 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5419 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5420 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5421 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5422 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5423 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5424 *
5425 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5426 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
5427 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5428 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5429 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5430 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5431 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005432 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5433 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005434 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5435 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5436 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5437 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5438 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5439 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5440 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5441 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5442 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5443 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5444 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5445 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005446 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
5447 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5448 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5449 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5451 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5452 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5453 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5454 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5455 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5456 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5457 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5458 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5459 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5460 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
5461 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // The F16 inputs are widened to F32 for the whole computation: the products, the ALPHA scaling
    // and the optional bias addition. The result is narrowed back to F16 exactly once, right
    // before it is stored.

    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // F32 accumulators, one per processed row of the output tile
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int i = 0;
    // Main loop: consume 4 columns of A (4 rows of B) per iteration
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining (COLS_A % 4) columns of A, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0); // b0 * (float8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1); // b0 * (float8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2); // b0 * (float8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3); // b0 * (float8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product. This is done in F32, before narrowing, so that
    // an accumulator outside the F16 range (|x| > 65504) is not turned into +/-inf before ALPHA
    // scales it back into range, and so that the result is rounded to F16 only once.
#if defined(ALPHA)
    acc0 *= (float8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 *= (float8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the bias vector (one value per output column), widened to F32 like the accumulators
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8 c0 = convert_float8(vload8(0, src2_addr));
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Narrow the final F32 results to F16 (single rounding step) for storing
    half8 hacc0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 hacc1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 hacc2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 hacc3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
5791
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5793 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005794 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5795 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005796 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5797 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel computes 8 output elements along X per work-item (half8 accumulators, vload8 of matrix B), so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 matches its vector width.
5799 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5800 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5801 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5802 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5803 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005804 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5805 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005806 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5807 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5808 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
5810 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005811 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5812 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005813 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5814 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5815 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5816 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5817 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5818 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5819 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5820 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5821 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5822 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5823 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5824 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
5826 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5827 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5828 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5830 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5831 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5832 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5833 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5834 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005835 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5836 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5837 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005838 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5839 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005840 */
5841__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5842 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005843#if defined(ADD_VEC_C)
5844 VECTOR_DECLARATION(src2),
5845#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005846 IMAGE_DECLARATION(dst),
5847 uint src0_stride_z,
5848 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005849 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005850#if defined(REINTERPRET_INPUT_AS_3D)
5851 ,
5852 uint src_cross_plane_pad
5853#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005854#if defined(REINTERPRET_OUTPUT_AS_3D)
5855 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005856 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005857#endif // REINTERPRET_OUTPUT_AS_3D
5858 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005859{
5860 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5861
5862 // Compute starting address for matrix A and Matrix B
5863 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5864
5865 // Update address for the matrix A
5866 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5867
5868 // Update address for the matrix B
5869 src_addr.s1 += idx * sizeof(half);
5870
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005871#if defined(REINTERPRET_INPUT_AS_3D)
5872 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5873 // in order to take into account the presence of possible cross plane paddings
5874 //
5875 // | |
5876 // | plane0 |
5877 // | |
5878 // |__________________|
5879 // |******************|
5880 // | cross_plane_pad |
5881 // |******************|
5882 // | |
5883 // | plane1 |
5884 // | |
5885 // |__________________|
5886
5887 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5888 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5889 zin = min(DEPTH_GEMM3D - 1, zin);
5890
5891 // Add offset due to the cross plane paddings
5892 zin *= (src_cross_plane_pad * src0_stride_y);
5893
5894 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5895 // multiply src0_stride_z by DEPTH_GEMM3D
5896 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5897
5898#else // defined(REINTERPRET_INPUT_AS_3D)
5899
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005900 // Add offset for batched GEMM
5901 src_addr.s0 += get_global_id(2) * src0_stride_z;
5902
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005903#endif // defined(REINTERPRET_INPUT_AS_3D)
5904
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005905#if defined(MATRIX_B_DEPTH)
5906 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5907 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5908#else // defined(MATRIX_B_DEPTH)
5909 src_addr.s1 += get_global_id(2) * src1_stride_z;
5910#endif // defined(MATRIX_B_DEPTH)
5911
5912 half8 acc0 = 0.0h;
5913#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5914 half8 acc1 = 0.0h;
5915#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5916#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5917 half8 acc2 = 0.0h;
5918#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5919#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5920 half8 acc3 = 0.0h;
5921#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5922
5923 int i = 0;
5924 for(; i <= ((int)COLS_A - 4); i += 4)
5925 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005926#if defined(REINTERPRET_INPUT_AS_3D)
5927 // Load values from matrix A
5928 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5929#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5930 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5931#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5932#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5933 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5934#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5935#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5936 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5937#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5938#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005939 // Load values from matrix A
5940 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5941#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5942 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5943#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5944#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5945 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5946#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5947#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5948 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5949#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005950#endif // defined(REINTERPRET_INPUT_AS_3D)
5951
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005952 // Load values from matrix B
5953 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5954 src_addr.s1 += src1_stride_y;
5955
5956 // Accumulate
5957 acc0 = fma(b0, (half8)a0.s0, acc0);
5958#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5959 acc1 = fma(b0, (half8)a1.s0, acc1);
5960#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5961#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5962 acc2 = fma(b0, (half8)a2.s0, acc2);
5963#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5964#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5965 acc3 = fma(b0, (half8)a3.s0, acc3);
5966#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5967
5968 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5969 src_addr.s1 += src1_stride_y;
5970 acc0 = fma(b0, (half8)a0.s1, acc0);
5971#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5972 acc1 = fma(b0, (half8)a1.s1, acc1);
5973#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5974#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5975 acc2 = fma(b0, (half8)a2.s1, acc2);
5976#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5977#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5978 acc3 = fma(b0, (half8)a3.s1, acc3);
5979#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5980
5981 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5982 src_addr.s1 += src1_stride_y;
5983 acc0 = fma(b0, (half8)a0.s2, acc0);
5984#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5985 acc1 = fma(b0, (half8)a1.s2, acc1);
5986#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5987#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5988 acc2 = fma(b0, (half8)a2.s2, acc2);
5989#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5990#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5991 acc3 = fma(b0, (half8)a3.s2, acc3);
5992#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5993
5994 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5995 src_addr.s1 += src1_stride_y;
5996 acc0 = fma(b0, (half8)a0.s3, acc0);
5997#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5998 acc1 = fma(b0, (half8)a1.s3, acc1);
5999#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6000#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6001 acc2 = fma(b0, (half8)a2.s3, acc2);
6002#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6003#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6004 acc3 = fma(b0, (half8)a3.s3, acc3);
6005#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6006
6007 src_addr.s0 += 4 * sizeof(half);
6008 }
6009
6010 for(; i < (int)COLS_A; ++i)
6011 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006012#if defined(REINTERPRET_INPUT_AS_3D)
6013 // Load values from matrix A
6014 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
6015#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6016 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
6017#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6018#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6019 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
6020#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6021#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6022 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
6023#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6024#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006025 // Load values from matrix A
6026 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6027#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6028 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6029#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6030#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6031 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6032#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6033#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6034 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6035#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006036#endif // defined(REINTERPRET_INPUT_AS_3D)
6037
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006038 // Load values from matrix B
6039 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6040
6041 src_addr += (int2)(sizeof(half), src1_stride_y);
6042
6043 // Accumulate
6044 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
6045#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6046 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
6047#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6048#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6049 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
6050#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6051#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6052 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
6053#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6054 }
6055
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006056 // Multiply by the weight of matrix-matrix product and store the result
6057#if defined(ALPHA)
6058 acc0 = acc0 * (half8)ALPHA;
6059#endif // defined(ALPHA)
6060#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
6061 acc1 = acc1 * (half8)ALPHA;
6062#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
6063#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
6064 acc2 = acc2 * (half8)ALPHA;
6065#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
6066#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
6067 acc3 = acc3 * (half8)ALPHA;
6068#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
6069
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00006070#if defined(ADD_VEC_C)
6071 // *INDENT-OFF*
6072 // clang-format off
6073 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
6074 half8 c0 = vload8(0, src2_addr);
6075 // clang-format on
6076 // *INDENT-ON*
6077
6078 acc0 += c0;
6079#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6080 acc1 += c0;
6081#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6082#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6083 acc2 += c0;
6084#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6085#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6086 acc3 += c0;
6087#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6088#endif /* defined(ADD_VEC_C) */
6089
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006090 int z = get_global_id(2);
6091
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006092 // Compute destination address
6093 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6094
6095 // Compute dst address
6096 __global uchar *dst_addr = offset(&dst, 0, 0);
6097
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006098#if defined(REINTERPRET_OUTPUT_AS_3D)
6099 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006100 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006101 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006102 // | |
6103 // | plane0 |
6104 // | |
6105 // |__________________|
6106 // |******************|
6107 // | cross_plane_pad |
6108 // |******************|
6109 // | |
6110 // | plane1 |
6111 // | |
6112 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006113
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006114 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
6115 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6116 zout = min(DEPTH_GEMM3D - 1, zout);
6117
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006118 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006119 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006120
6121 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6122 // multiply dst_stride_z by DEPTH_GEMM3D
6123 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
6124
6125 // Store the output block
6126 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
6127#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6128 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
6129#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6130#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6131 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
6132#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6133#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6134 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
6135#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6136
6137#else // defined(REINTERPRET_OUTPUT_AS_3D)
6138 // Add offset for batched GEMM
6139 dst_addr += z * dst_stride_z;
6140
6141 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006142 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
6143#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006144 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
6145#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6146#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006147 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
6148#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6149#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006150 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
6151#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006152#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006153}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006154#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006155
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006156#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006157
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006158#if defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006159/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
6160 *
Gian Marco19835e52018-01-30 13:35:54 +00006161 * @note The beta's value need to be passed at compile time using -DBETA
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006162 *
6163 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F32
6164 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
6165 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6166 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
6167 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006168 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
6169 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006170 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006171 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006172 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6173 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6174 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6175 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006176 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
6177 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006178 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6179 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006180__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
6181 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006182{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006183 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006184 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
6185 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006186
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006187 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006188 float4 alpha_ab = vload4(0, (__global float *)dst.ptr);
6189
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006190 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006191 float4 c = vload4(0, (__global float *)src.ptr);
6192
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006193 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006194 float4 out = alpha_ab + (float4)BETA * c;
6195
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006196 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006197 vstore4(out, 0, (__global float *)dst.ptr);
6198}
6199
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006200#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006201/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
6202 *
Gian Marco19835e52018-01-30 13:35:54 +00006203 * @note The beta's value need to be passed at compile time using -DBETA
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006204 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006205 * @param[in] src_ptr Pointer to the source matrix. Supported data types: F16
6206 * @param[in] src_stride_x Stride of the source matrix in X dimension (in bytes)
6207 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6208 * @param[in] src_stride_y Stride of the source matrix in Y dimension (in bytes)
6209 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006210 * @param[in] src_stride_z Stride of the destination tensor in Z dimension (in bytes)
6211 * @param[in] src_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006212 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006213 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006214 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6215 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6216 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6217 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006218 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
6219 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006220 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6221 */
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006222__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
6223 TENSOR3D_DECLARATION(dst))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006224{
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006225 // Compute source and destination addresses
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006226 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
6227 Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006228
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006229 // Load values from A x B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006230 half8 alpha_ab = vload8(0, (__global half *)dst.ptr);
6231
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006232 // Load values from Matrix C
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006233 half8 c = vload8(0, (__global half *)src.ptr);
6234
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006235 // Computes alpha * axb + beta * c
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006236 half8 out = alpha_ab + (half8)BETA * c;
6237
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006238 // Store final result in axb matrix
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006239 vstore8(out, 0, (__global half *)dst.ptr);
6240}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006241#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006242#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006243
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006244#if defined(WIDTH_VECTOR_A)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006245/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
6246 *
Gian Marco19835e52018-01-30 13:35:54 +00006247 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006248 *
Gian Marco19835e52018-01-30 13:35:54 +00006249 * @note The input A and matrix B must not be reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006250 *
6251 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
6252 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
6253 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6254 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
6255 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6256 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006257 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006258 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
6259 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6260 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
6261 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6262 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
6263 * @param[in] src1_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
6264 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01006265 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006266 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6267 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6268 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6269 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
6270 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6271 */
6272__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
6273 TENSOR3D_DECLARATION(src1),
6274 IMAGE_DECLARATION(dst))
6275{
6276 int idx = get_global_id(0) * 4;
6277 int idy = get_global_id(1);
6278
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006279 // Compute the address for the vector A and matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006280 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
6281 src_addr.s1 += idx * sizeof(float);
6282
6283 int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));
6284
6285 float4 acc = 0.0f;
6286
Georgios Pinitas96880cf2017-10-20 18:52:20 +01006287 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006288 {
6289 float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
6290 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
6291 float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));
6292
6293 acc += b0 * (float4)a0.s0;
6294 acc += b1 * (float4)a0.s1;
6295 }
6296
6297 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
6298 {
6299 float a0 = *((__global float *)(src0_ptr + src_addr.s0));
6300 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
6301
6302 acc += b0 * (float4)a0;
6303 }
6304
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006305 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006306 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6307
6308 vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
6309}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006310#endif // defined(WIDTH_VECTOR_A)
6311
6312/** This kernel accumulates each row with the biases vector.
6313 *
6314 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
6315 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
6316 *
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01006317 * @param[in, out] accum_ptr Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006318 * @param[in] accum_stride_x Stride of the accmulate tensor in X dimension (in bytes)
6319 * @param[in] accum_step_x accum_stride_x * number of elements along X processed per workitem(in bytes)
6320 * @param[in] accum_stride_y Stride of the accumlulate tensor in Y dimension (in bytes)
6321 * @param[in] accum_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6322 * @param[in] accum_offset_first_element_in_bytes The offset of the first element in the accumulate tensor
6323 * @param[in] biases_ptr Pointer to the biases vector. Same as @p accum_ptr
6324 * @param[in] biases_stride_x Stride of the destination tensor in X dimension (in bytes)
6325 * @param[in] biases_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
6326 * @param[in] biases_offset_first_element_in_bytes The offset of the first element in the destination tensor
6327 */
6328#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
6329__kernel void gemm_accumulate_biases(
6330 IMAGE_DECLARATION(accum),
6331 VECTOR_DECLARATION(biases))
6332{
6333 Image accum = CONVERT_TO_IMAGE_STRUCT(accum);
6334 Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);
6335
6336 // Vector size, i.e. number of vector elements.
6337 VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
6338 accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
6339 VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
6340 biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01006341 accum_value = biases_value + accum_value;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006342 // Store result in the accumulate buffer
6343 VSTORE(VECTOR_SIZE)
6344 (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
6345}
6346#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)