blob: 7ada14c77463d1ce22a367784038806b5f346e7e [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
28#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
29#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
30#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
31#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
32#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
33#define CONCAT_INC(K0) INC##K0
34#define INC(K0) CONCAT_INC(K0)
35
36#if(SRC_WIDTH % K0)
37#define BOUNDARY_CONDITION_X(x, a) \
38 ({ \
39 a = select(0, a, CONVERT(((x * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
40 })
41#else // (SRC_WIDTH % K0)
42#define BOUNDARY_CONDITION_X(x, a) \
43 ({})
44#endif // (SRC_WIDTH % K0)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000045
46/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
47 * the output matrix unrolling the values.
48 *
49 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000050 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000051 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
52 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
53 * @note Only the following values for M0, K0 and V0 are supported:
54 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +000055 * K0: 2,3,4,8,16
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000056 * V0: greater than 0
57 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
58 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
59 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
60 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
61 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
62 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
63 *
64 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
65 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
66 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
67 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
68 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
69 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
70 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
71 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
72 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
73 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
74 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
75 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
76 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
77 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
78 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
79 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
80 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
81 */
82__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
83 TENSOR3D_DECLARATION(dst)
84#if defined(REINTERPRET_INPUT_AS_3D)
85 ,
86 uint cross_plane_pad
87#endif // REINTERPRET_INPUT_AS_3D
88 )
89{
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000090 // Block size
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000091#define BLOCK_SIZE ((M0) * (K0))
92
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +000093 // Output offset X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +000094#if defined(INTERLEAVE)
95#define OUTPUT_OFFSET_X (K0)
96#else // defined(INTERLEAVE)
97#define OUTPUT_OFFSET_X (BLOCK_SIZE)
98#endif // defined(INTERLEAVE)
99
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +0000100 // Output step X
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000101#if defined(INTERLEAVE)
102#define OUTPUT_STEP_X (K0) * (V0)
103#else // Do not interleave
104#define OUTPUT_STEP_X (K0)
105#endif // defined(INTERLEAVE)
106
107 // Compute source and destination addresses
108 uint x = get_global_id(0);
109 uint y = get_global_id(1);
110 uint z = get_global_id(2);
111
112 // ------------------ Compute input/output addresses ---------------------------
113
114 // Compute the input address
115 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
116
117 // Compute the output address
118 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
119 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
120
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000121 // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
122 REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000123
124#if defined(REINTERPRET_INPUT_AS_3D)
125 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
126 // multiply src_stride_z by DEPTH_GEMM3D
127
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000128 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
129
130 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +0100131 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000132
133#else // defined(REINTERPRET_INPUT_AS_3D)
134
135 input_ptr += z * (uint)src_stride_z;
136
137#endif // defined(REINTERPRET_INPUT_AS_3D)
138
139 // Add offset for batched GEMM
140 output_ptr += z * (uint)dst_stride_z;
141
142 // ---------------------------Load input values --------------------------------
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000143 // Load values from the LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +0100144 LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000145 BOUNDARY_CONDITION_X(x, a0);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000146#if M0 > 1
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000147 BOUNDARY_CONDITION_X(x, a1);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000148#endif // M0 > 1
149#if M0 > 2
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000150 BOUNDARY_CONDITION_X(x, a2);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000151#endif // M0 > 2
152#if M0 > 3
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000153 BOUNDARY_CONDITION_X(x, a3);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000154#endif // M0 > 3
155#if M0 > 4
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000156 BOUNDARY_CONDITION_X(x, a4);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000157#endif // M0 > 4
158#if M0 > 5
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000159 BOUNDARY_CONDITION_X(x, a5);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000160#endif // M0 > 5
161#if M0 > 6
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000162 BOUNDARY_CONDITION_X(x, a6);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000163#endif // M0 > 6
164#if M0 > 7
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000165 BOUNDARY_CONDITION_X(x, a7);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000166#endif // M0 > 7
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000167 // ---------------------------Store output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100168 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
169 STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000170
171#undef BLOCK_SIZE
172#undef OUTPUT_OFFSET_X
173#undef OUTPUT_STEP_X
174}
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000175
176#if M0 == 2
177#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
178 ({ \
179 VEC_DATA_TYPE(DATA_TYPE, M0) \
180 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
181 VSTORE(M0) \
182 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
183 })
184#elif M0 == 3 // M0 == 3
185#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
186 ({ \
187 VEC_DATA_TYPE(DATA_TYPE, M0) \
188 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
189 VSTORE(M0) \
190 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
191 })
192#elif M0 == 4 // M0 == 4
193#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
194 ({ \
195 VEC_DATA_TYPE(DATA_TYPE, M0) \
196 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
197 VSTORE(M0) \
198 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
199 })
200#elif M0 == 5 // M0 == 5
201#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
202 ({ \
203 VEC_DATA_TYPE(DATA_TYPE, 4) \
204 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
205 DATA_TYPE res1 = a4.s##i; \
206 VSTORE(4) \
207 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
208 *((__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4) = res1; \
209 })
210#elif M0 == 6 // M0 == 6
211#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
212 ({ \
213 VEC_DATA_TYPE(DATA_TYPE, 4) \
214 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
215 VEC_DATA_TYPE(DATA_TYPE, 2) \
216 res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
217 VSTORE(4) \
218 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
219 VSTORE(2) \
220 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
221 })
222#elif M0 == 7 // M0 == 7
223#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
224 ({ \
225 VEC_DATA_TYPE(DATA_TYPE, 4) \
226 res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
227 VEC_DATA_TYPE(DATA_TYPE, 3) \
228 res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
229 VSTORE(4) \
230 (res0, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
231 VSTORE(3) \
232 (res1, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE)) + 4); \
233 })
234#elif M0 == 8 // M0 == 8
235#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
236 ({ \
237 VEC_DATA_TYPE(DATA_TYPE, M0) \
238 res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
239 VSTORE(M0) \
240 (res, 0, (__global DATA_TYPE *)(output_ptr + 0x##i * output_step_x * sizeof(DATA_TYPE))); \
241 })
242#else // M0 not supported
243#error "M0 value not supported"
244#endif // N0 conditions
245
246/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
247 * the output matrix unrolling the values.
248 *
249 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000250 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (i.e. -DSRC_WIDTH=16)
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000251 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
252 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
253 * @note Only the following values for M0, K0 and V0 are supported:
254 * M0: 2,3,4,5,6,7,8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000255 * K0: 2,3,4,8,16
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000256 * V0: greater than 0
257 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
258 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
259 * -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
260 * -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
261 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
262 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
263 *
264 * @param[in] src_ptr Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
265 * @param[in] src_stride_x Stride of the source LHS tensor in X dimension (in bytes)
266 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
267 * @param[in] src_stride_y Stride of the source LHS tensor in Y dimension (in bytes)
268 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
269 * @param[in] src_stride_z Stride of the source LHS tensor in Z dimension (in bytes)
270 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
271 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
272 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
273 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
274 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
275 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
276 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
277 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
278 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
280 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
281 */
282__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
283 TENSOR3D_DECLARATION(dst)
284#if defined(REINTERPRET_INPUT_AS_3D)
285 ,
286 uint cross_plane_pad
287#endif // REINTERPRET_INPUT_AS_3D
288 )
289{
290 // Block size
291#define BLOCK_SIZE ((M0) * (K0))
292
293 // Output offset X
294#if defined(INTERLEAVE)
295#define OUTPUT_OFFSET_X (M0)
296#else // defined(INTERLEAVE)
297#define OUTPUT_OFFSET_X (BLOCK_SIZE)
298#endif // defined(INTERLEAVE)
299
300 // Output step X
301#if defined(INTERLEAVE)
302#define OUTPUT_STEP_X (M0) * (V0)
303#else // Do not interleave
304#define OUTPUT_STEP_X (M0)
305#endif // defined(INTERLEAVE)
306
307 // Compute source and destination addresses
308 uint x = get_global_id(0);
309 uint y = get_global_id(1);
310 uint z = get_global_id(2);
311
312 // ------------------ Compute input/output addresses ---------------------------
313
314 // Compute the input address
315 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;
316
317 // Compute the output address
318 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % V0) *
319 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));
320
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000321 // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
322 REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000323
324#if defined(REINTERPRET_INPUT_AS_3D)
325 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
326 // multiply src_stride_z by DEPTH_GEMM3D
327
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000328 input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;
329
330 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +0100331 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000332
333#else // defined(REINTERPRET_INPUT_AS_3D)
334
335 input_ptr += z * (uint)src_stride_z;
336
337#endif // defined(REINTERPRET_INPUT_AS_3D)
338
339 // Add offset for batched GEMM
340 output_ptr += z * (uint)dst_stride_z;
341
342 // ---------------------------Load input values --------------------------------
343
344 // Load values from the LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +0100345 LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000346 BOUNDARY_CONDITION_X(x, a0);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000347#if M0 > 1
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000348 BOUNDARY_CONDITION_X(x, a1);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000349#endif // M0 > 1
350#if M0 > 2
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000351 BOUNDARY_CONDITION_X(x, a2);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000352#endif // M0 > 2
353#if M0 > 3
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000354 BOUNDARY_CONDITION_X(x, a3);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000355#endif // M0 > 3
356#if M0 > 4
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000357 BOUNDARY_CONDITION_X(x, a4);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000358#endif // M0 > 4
359#if M0 > 5
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000360 BOUNDARY_CONDITION_X(x, a5);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000361#endif // M0 > 5
362#if M0 > 6
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000363 BOUNDARY_CONDITION_X(x, a6);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000364#endif // M0 > 6
365#if M0 > 7
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000366 BOUNDARY_CONDITION_X(x, a7);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000367#endif // M0 > 7
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000368 // ---------------------------Transpose and store block -----------------------
369
370 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
371 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
372#if K0 > 2
373 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000374#endif // K0 > 2
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000375#if K0 > 3
376 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
377#endif // K0 > 3
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000378#if K0 > 4
379 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
380 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
381 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
382 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
383#endif // K0 > 4
384#if K0 > 8
385 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
386 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
387 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
388 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
389 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
390 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
391 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
392 TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
393#endif // K0 > 8
394
395#undef BLOCK_SIZE
396#undef OUTPUT_OFFSET_X
397#undef OUTPUT_STEP_X
398}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
402/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
403 * the output matrix unrolling the values.
404 *
405 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
406 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
407 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
408 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
409 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
410 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000411 * N0: 2,3,4,8,16
412 * K0: 1,2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000413 * H0: greater than 0
414 *
415 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
416 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
417 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
418 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
419 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
420 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
421 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
422 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
423 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
424 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
425 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
426 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
427 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
428 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
429 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
430 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
431 */
432__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
433 TENSOR3D_DECLARATION(dst))
434{
435 // Block size
436#define BLOCK_SIZE ((K0) * (N0))
437
438 // Output offset X
439#if defined(INTERLEAVE)
440#define OUTPUT_OFFSET_X (N0)
441#else // defined(INTERLEAVE)
442#define OUTPUT_OFFSET_X (BLOCK_SIZE)
443#endif // defined(INTERLEAVE)
444
445 // Output step X
446#if defined(INTERLEAVE)
447#define OUTPUT_STEP_X (N0) * (H0)
448#else // Do not interleave
449#define OUTPUT_STEP_X (N0)
450#endif // defined(INTERLEAVE)
451
452 // Compute source and destination addresses
453 uint x = get_global_id(0);
454 uint y = get_global_id(1);
455 uint z = get_global_id(2);
456
457 // ------------------ Compute input/output addresses ---------------------------
458
459 // Compute the input address
460 __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;
461
462 // Compute the output address
463 __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
464 x / (uint)H0)
465 * (uint)dst_stride_y)
466 + z * (uint)dst_stride_z;
467
468 // ---------------------------Load input values --------------------------------
469
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +0000470 REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); ////uint a0=0, a1=0, a2=0...a(M0-1)=0;
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000471
472 // Load values from the RHS matrix
473 a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
474#if K0 > 1
475 if(y * (uint)K0 + 1 < SRC_HEIGHT)
476 {
477 a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
478 }
479#endif // K0 > 1
480#if K0 > 2
481 if(y * (uint)K0 + 2 < SRC_HEIGHT)
482 {
483 a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
484 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000485#endif // K0 > 2
486#if K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000487 if(y * (uint)K0 + 3 < SRC_HEIGHT)
488 {
489 a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
490 }
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000491#endif // K0 > 3
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000492#if K0 > 4
493 if(y * (uint)K0 + 4 < SRC_HEIGHT)
494 {
495 a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
496 }
497 if(y * (uint)K0 + 5 < SRC_HEIGHT)
498 {
499 a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
500 }
501 if(y * (uint)K0 + 6 < SRC_HEIGHT)
502 {
503 a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
504 }
505 if(y * (uint)K0 + 7 < SRC_HEIGHT)
506 {
507 a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
508 }
509#endif // K0 > 4
510#if K0 > 8
Gian Marco Iodice08ddd7b2018-12-19 10:01:18 +0000511 if(y * (uint)K0 + 8 < SRC_HEIGHT)
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000512 {
513 a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
514 }
515 if(y * (uint)K0 + 9 < SRC_HEIGHT)
516 {
517 a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
518 }
519 if(y * (uint)K0 + 10 < SRC_HEIGHT)
520 {
521 aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
522 }
523 if(y * (uint)K0 + 11 < SRC_HEIGHT)
524 {
525 aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
526 }
527 if(y * (uint)K0 + 12 < SRC_HEIGHT)
528 {
529 aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
530 }
531 if(y * (uint)K0 + 13 < SRC_HEIGHT)
532 {
533 aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
534 }
535 if(y * (uint)K0 + 14 < SRC_HEIGHT)
536 {
537 aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
538 }
539 if(y * (uint)K0 + 15 < SRC_HEIGHT)
540 {
541 aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
542 }
543#endif // K0 > 8
544
545 // ---------------------------Store output values ------------------------------
Usama Arif0681e3b2019-04-25 14:28:07 +0100546 REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
547 STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000548
549#undef BLOCK_SIZE
550#undef OUTPUT_OFFSET_X
551#undef OUTPUT_STEP_X
552}
553
554#if defined(TRANSPOSE)
555/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
556 * the output matrix unrolling the values.
557 *
558 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
559 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (i.e. -DSRC_HEIGHT=16)
560 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (i.e. -DK0=2, -DN0=2).
561 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (i.e. -DH0=2)
562 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must passed at compile time.
563 * @note The option -DTRANSPOSE must passed at compile time.
564 * @note Only the following values for K0, N0 and H0 are supported:
Gian Marco Iodicebacfec52019-01-11 11:30:55 +0000565 * N0: 2,3,4,8,16
566 * K0: 2,3,4,8,16
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000567 * H0: greater than 0
568 *
569 * @param[in] src_ptr Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
570 * @param[in] src_stride_x Stride of the source RHS tensor in X dimension (in bytes)
571 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
572 * @param[in] src_stride_y Stride of the source RHS tensor in Y dimension (in bytes)
573 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
574 * @param[in] src_stride_z Stride of the source RHS tensor in Z dimension (in bytes)
575 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
576 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
577 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src_ptr
578 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
579 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
580 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
581 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
582 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
583 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
584 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
585 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // Distance (in elements) between the starts of two blocks stored on the same output row:
    // interleaved blocks start K0 elements apart, non-interleaved blocks are stored back to back.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Distance (in elements) between two consecutive transposed rows of the same block.
    // Fully parenthesized so the macro expands safely inside any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the N0-wide column block, y the K0-tall row block, z the plane
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Input address: first element of the K0xN0 block read by this work-item
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Output address: H0 consecutive x-blocks share one output row (x / H0), (x % H0) selects the
    // slot within that row, and each step in y advances by H0 whole blocks along the row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix.
    // Row 0 is always loaded; rows 1..K0-1 are only loaded while they lie inside SRC_HEIGHT,
    // otherwise they keep their zero initialization (zero padding of the last block).
    // NOTE(review): the unconditional load of row 0 assumes y * K0 < SRC_HEIGHT for every dispatched y.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    // res<n> gathers column n of the loaded block, i.e. (a0.s<n>, a1.s<n>, ..., a(K0-1).s<n>)
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

#if K0 == 2
    // This part computes the following transpositions:
    // 2x2 -> 2x2
    // 2x4 -> 4x2
    // 2x8 -> 8x2
    // 2x16 -> 16x2
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
#endif // N0 > 8

#elif K0 == 3 // K0 == 3
    // This part computes the following transpositions:
    // 3x2 -> 2x3
    // 3x4 -> 4x3
    // 3x8 -> 8x3
    // 3x16 -> 16x3
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8

#elif K0 == 4 // K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not in {2, 3, 4, 8, 16}
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // Store N0 transposed rows of K0 elements each, OUTPUT_STEP_X elements apart
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
881#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
// Token-paste helper. Callers pass an already macro-expanded K0 through an outer macro,
// so CONCAT(ARM_DOT, K0) resolves to e.g. ARM_DOT4.
#define CONCAT(a, b) a##b

// ARM_DOT<n>(a, b, c): accumulate the n-element dot product of a and b into c, i.e.
// c = fma(a.s(n-1), b.s(n-1), ... fma(a.s0, b.s0, c)), one fma per element in ascending
// element order. ARM_DOT1 is the scalar step every wider variant is built from.
#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)        \
    ({                           \
        ARM_DOT1(a.s0, b.s0, c); \
        ARM_DOT1(a.s1, b.s1, c); \
    })
#define ARM_DOT3(a, b, c)        \
    ({                           \
        ARM_DOT2(a, b, c);       \
        ARM_DOT1(a.s2, b.s2, c); \
    })
#define ARM_DOT4(a, b, c)        \
    ({                           \
        ARM_DOT3(a, b, c);       \
        ARM_DOT1(a.s3, b.s3, c); \
    })
// The 8- and 16-wide variants split each operand into its low and high halves
#define ARM_DOT8(a, b, c)          \
    ({                             \
        ARM_DOT4(a.lo, b.lo, c);   \
        ARM_DOT4(a.hi, b.hi, c);   \
    })
#define ARM_DOT16(a, b, c)         \
    ({                             \
        ARM_DOT8(a.lo, b.lo, c);   \
        ARM_DOT8(a.hi, b.hi, c);   \
    })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note If -DALPHA is passed at compile time, the accumulated result is multiplied by ALPHA before the bias is added.
 * @note If -DBETA is passed at compile time, the bias matrix is multiplied by BETA (skipped when -DUNIT_BETA is also passed)
 *       and added to the result. With -DBROADCAST_BIAS a single bias row of N0 elements is added to every one of the M0 rows.
 * @note If -DMATRIX_B_DEPTH is passed at compile time, the RHS matrix is selected along Z with (z % MATRIX_B_DEPTH) instead of z.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - H0 >= 1
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of the matrix being reinterpreted
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
                                          IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                          IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                          IMAGE_DECLARATION(dst),
                                          uint lhs_stride_z,
                                          uint rhs_stride_z,
#if defined(BETA)
                                          uint bias_stride_z,
#endif //defined(BETA)
                                          uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                          ,
                                          uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                          ,
                                          uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                         )
{
    // Number of elements in one K0xN0 RHS block
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X (in elements):
    // - RHS_OFFSET_X: distance between the blocks of two consecutive x on the same reshaped row
    // - RHS_STEP_X:   distance between two consecutive transposed rows (b0, b1, ...) of one block
    // - RHS_STEP_LOOP: extra factor applied when advancing to the next K0 slice
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose output block starts outside the M x N result have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    // Main loop: consume K0 columns of LHS / K0 elements of each transposed RHS row per iteration
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(K0, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(K0, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(K0, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(K0, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(K0, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(K0, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(K0, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(K0, a7, b, c7);
#endif // M0 > 7

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    // Left-over accumulations: remaining K % K0 columns, one element at a time
    for(; i < K; ++i)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate
        ARM_DOT_K0XN0(1, a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(1, a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(1, a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(1, a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(1, a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(1, a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(1, a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(1, a7, b, c7);
#endif // M0 > 7

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Undefine every helper macro defined above so none leaks into the rest of the file
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001290
/** Accumulate c += a * b through the fma() built-in (component-wise when the operands are vectors).
 *
 * Every argument is parenthesized so that callers can pass arbitrary expressions
 * (casts, swizzles, pasted names) without precedence surprises. The GNU statement-expression
 * form lets the macro be used as a single statement.
 *
 * @param[in]     a First multiplicand
 * @param[in]     b Second multiplicand
 * @param[in,out] c Accumulator (must be an lvalue); overwritten with fma(a, b, c)
 */
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
1295
1296#if M0 == 1
1297#define LD_RHS_VFMA_M0xN0(i, a, c) \
1298 ({ \
1299 VEC_DATA_TYPE(DATA_TYPE, N0) \
1300 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1301 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1302 })
1303#elif M0 == 2 // M0 == 2
1304#define LD_RHS_VFMA_M0xN0(i, a, c) \
1305 ({ \
1306 VEC_DATA_TYPE(DATA_TYPE, N0) \
1307 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1308 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1309 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1310 })
1311#elif M0 == 3 // M0 == 3
1312#define LD_RHS_VFMA_M0xN0(i, a, c) \
1313 ({ \
1314 VEC_DATA_TYPE(DATA_TYPE, N0) \
1315 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1316 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1317 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1318 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1319 })
1320#elif M0 == 4 // M0 == 4
1321#define LD_RHS_VFMA_M0xN0(i, a, c) \
1322 ({ \
1323 VEC_DATA_TYPE(DATA_TYPE, N0) \
1324 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1325 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1326 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1327 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1328 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1329 })
1330#elif M0 == 5 // M0 == 5
1331#define LD_RHS_VFMA_M0xN0(i, a, c) \
1332 ({ \
1333 VEC_DATA_TYPE(DATA_TYPE, N0) \
1334 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1335 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1336 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1337 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1338 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1339 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1340 })
1341#elif M0 == 6 // M0 == 6
1342#define LD_RHS_VFMA_M0xN0(i, a, c) \
1343 ({ \
1344 VEC_DATA_TYPE(DATA_TYPE, N0) \
1345 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1346 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1347 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1348 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1349 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1350 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1351 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1352 })
1353#elif M0 == 7 // M0 == 7
1354#define LD_RHS_VFMA_M0xN0(i, a, c) \
1355 ({ \
1356 VEC_DATA_TYPE(DATA_TYPE, N0) \
1357 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1358 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1359 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1360 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1361 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1362 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1363 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1364 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1365 })
1366#elif M0 == 8 // M0 == 8
1367#define LD_RHS_VFMA_M0xN0(i, a, c) \
1368 ({ \
1369 VEC_DATA_TYPE(DATA_TYPE, N0) \
1370 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
1371 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
1372 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
1373 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
1374 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
1375 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
1376 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
1377 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
1378 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
1379 })
1380#else // M0 not supported
1381#error "M0 not supported"
1382#endif // M0 not supported
1383
1384/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1385 * The LHS matrix is NOT reshaped
1386 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1387 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001388 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1389 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90).
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001390 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
1391 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
1392 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1393 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1394 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1395 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1396 * - N0 = 2, 3, 4, 8, 16
1397 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001398 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001399 *
1400 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1401 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1402 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1403 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1404 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1405 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1406 *
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001407 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1408 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1409 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1410 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1411 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1412 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1413 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1414 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1415 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1416 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1417 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1418 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001419 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1420 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001421 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001422 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001423 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1424 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1425 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1426 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1427 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1428 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1429 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1430 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1431 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1432 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001433 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001434 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1435 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1436 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001437 */
1438__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1439 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001440#if defined(BETA)
1441 IMAGE_DECLARATION(bias),
1442#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001443 IMAGE_DECLARATION(dst),
1444 uint lhs_stride_z,
1445 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001446#if defined(BETA)
1447 uint bias_stride_z,
1448#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001449 uint dst_stride_z
1450#if defined(REINTERPRET_INPUT_AS_3D)
1451 ,
1452 uint lhs_cross_plane_pad
1453#endif // REINTERPRET_INPUT_AS_3D
1454#if defined(REINTERPRET_OUTPUT_AS_3D)
1455 ,
1456 uint dst_cross_plane_pad
1457#endif // REINTERPRET_OUTPUT_AS_3D
1458 )
1459{
1460 // Block size
1461#define RHS_BLOCK_SIZE ((K0) * (N0))
1462
1463 // RHS offset and step X
1464#if defined(RHS_INTERLEAVE)
1465#define RHS_OFFSET_X (N0)
1466#define RHS_STEP_X ((N0) * (H0))
1467#define RHS_STEP_LOOP (1)
1468#else // defined(RHS_INTERLEAVE)
1469#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1470#define RHS_STEP_X (N0)
1471#define RHS_STEP_LOOP (H0)
1472#endif // defined(RHS_INTERLEAVE)
1473
1474 uint x = get_global_id(0);
1475 uint y = get_global_id(1);
1476 uint z = get_global_id(2);
1477
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001478#if defined(DUMMY_WORK_ITEMS)
1479 if((x * N0 >= N) || (y * M0 >= M))
1480 {
1481 return;
1482 }
1483#endif // defined(DUMMY_WORK_ITEMS)
1484
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001485 // Compute LHS matrix address
1486 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1487
1488 // Compute RHS matrix address
1489 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1490
1491#if defined(MATRIX_B_DEPTH)
1492 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1493 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1494#else // defined(MATRIX_B_DEPTH)
1495 rhs_offset += z * rhs_stride_z;
1496#endif // defined(MATRIX_B_DEPTH)
1497
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001498 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1499 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001500
1501#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001502
1503 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001504 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001505
1506 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1507 // multiply lhs_stride_z by DEPTH_GEMM3D
1508 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1509
1510#else // defined(REINTERPRET_INPUT_AS_3D)
1511
1512 // Add offset for batched GEMM
1513 lhs_offset += z * lhs_stride_z;
1514
1515#endif // defined(REINTERPRET_INPUT_AS_3D)
1516
1517 // Initialize the accumulators
1518 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1519
1520 int i = 0;
1521 for(; i <= (K - K0); i += K0)
1522 {
1523 // Supported cases (M0, K0):
1524 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1525 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1526 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1527 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1528 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1529 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1530 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1531 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1532 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001533 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001534
1535 LD_RHS_VFMA_M0xN0(0, a, c);
1536 LD_RHS_VFMA_M0xN0(1, a, c);
1537#if K0 > 2
1538 LD_RHS_VFMA_M0xN0(2, a, c);
1539#endif // K0 > 2
1540#if K0 > 3
1541 LD_RHS_VFMA_M0xN0(3, a, c);
1542#endif // K0 > 3
1543#if K0 > 4
1544 LD_RHS_VFMA_M0xN0(4, a, c);
1545 LD_RHS_VFMA_M0xN0(5, a, c);
1546 LD_RHS_VFMA_M0xN0(6, a, c);
1547 LD_RHS_VFMA_M0xN0(7, a, c);
1548#endif // K0 > 4
1549#if K0 > 8
1550 LD_RHS_VFMA_M0xN0(8, a, c);
1551 LD_RHS_VFMA_M0xN0(9, a, c);
1552 LD_RHS_VFMA_M0xN0(A, a, c);
1553 LD_RHS_VFMA_M0xN0(B, a, c);
1554 LD_RHS_VFMA_M0xN0(C, a, c);
1555 LD_RHS_VFMA_M0xN0(D, a, c);
1556 LD_RHS_VFMA_M0xN0(E, a, c);
1557 LD_RHS_VFMA_M0xN0(F, a, c);
1558#endif // K0 > 8
1559
1560 lhs_offset += K0 * sizeof(DATA_TYPE);
1561 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1562 }
1563
1564 // Left-over accumulations
1565 for(; i < K; ++i)
1566 {
1567 // Load values from LHS matrix
1568 VEC_DATA_TYPE(DATA_TYPE, 2)
1569 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1570#if M0 > 1
1571 VEC_DATA_TYPE(DATA_TYPE, 2)
1572 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1573#endif // M0 > 1
1574#if M0 > 2
1575 VEC_DATA_TYPE(DATA_TYPE, 2)
1576 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1577#endif // M0 > 2
1578#if M0 > 3
1579 VEC_DATA_TYPE(DATA_TYPE, 2)
1580 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1581#endif // M0 > 3
1582#if M0 > 4
1583 VEC_DATA_TYPE(DATA_TYPE, 2)
1584 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1585#endif // M0 > 4
1586#if M0 > 5
1587 VEC_DATA_TYPE(DATA_TYPE, 2)
1588 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1589#endif // M0 > 5
1590#if M0 > 6
1591 VEC_DATA_TYPE(DATA_TYPE, 2)
1592 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1593#endif // M0 > 6
1594#if M0 > 7
1595 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001596 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001597#endif // M0 > 7
1598
1599 LD_RHS_VFMA_M0xN0(0, a, c);
1600
1601 lhs_offset += sizeof(DATA_TYPE);
1602 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1603 }
1604
1605 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1606
1607 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1608
1609#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001610 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001611 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001612
1613 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1614 // multiply dst_stride_z by DEPTH_GEMM3D
1615 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1616
1617#else // defined(REINTERPRET_OUTPUT_AS_3D)
1618
1619 // Add offset for batched GEMM
1620 dst_addr += z * dst_stride_z;
1621
1622#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1623
1624 // Multiply by the weight of matrix-matrix product and store the result
1625#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001626 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001627#endif // defined(ALPHA)
1628
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001629 // Add beta*bias
1630#if defined(BETA)
1631#if defined(BROADCAST_BIAS)
1632 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1633
1634 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1635
1636#ifndef UNIT_BETA
1637 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1638#endif // UNIT_BIAS
1639
1640 // c = c + bias[broadcasted]
1641 ADD_BLOCK_BROADCAST(M0, c, bias0);
1642
1643#else // defined(BROADCAST_BIAS)
1644 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1645 2) * bias_stride_z;
1646
1647 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1648
1649#ifndef UNIT_BETA
1650 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1651#endif // UNIT_BIAS
1652
1653 // c = c + bias
1654 ADD_BLOCK(M0, c, bias);
1655
1656#endif // defined(BROADCAST_BIAS)
1657#endif // defined(BETA)
1658
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001659 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001660 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001661
1662#undef RHS_BLOCK_SIZE
1663#undef RHS_OFFSET_X
1664#undef RHS_STEP_X
1665}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001666#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001667
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001668#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001669
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001670#if K0 == 2
1671#define ARM_DOT_K0(a, b, c) \
1672 ({ \
1673 c = fma(a.s0, b.s0, c); \
1674 c = fma(a.s1, b.s1, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001675 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001676#elif K0 == 3 // K0 == 3
1677#define ARM_DOT_K0(a, b, c) \
1678 ({ \
1679 c = fma(a.s0, b.s0, c); \
1680 c = fma(a.s1, b.s1, c); \
1681 c = fma(a.s2, b.s2, c); \
1682 })
1683#elif K0 == 4 // K0 == 4
1684#define ARM_DOT_K0(a, b, c) \
1685 ({ \
1686 c = fma(a.s0, b.s0, c); \
1687 c = fma(a.s1, b.s1, c); \
1688 c = fma(a.s2, b.s2, c); \
1689 c = fma(a.s3, b.s3, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001690 })
1691#elif K0 == 8 // K0 == 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001692#define ARM_DOT_K0(a, b, c) \
1693 ({ \
1694 c = fma(a.s0, b.s0, c); \
1695 c = fma(a.s1, b.s1, c); \
1696 c = fma(a.s2, b.s2, c); \
1697 c = fma(a.s3, b.s3, c); \
1698 c = fma(a.s4, b.s4, c); \
1699 c = fma(a.s5, b.s5, c); \
1700 c = fma(a.s6, b.s6, c); \
1701 c = fma(a.s7, b.s7, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001702 })
1703#elif K0 == 16 // K0 == 16
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001704#define ARM_DOT_K0(a, b, c) \
1705 ({ \
1706 c = fma(a.s0, b.s0, c); \
1707 c = fma(a.s1, b.s1, c); \
1708 c = fma(a.s2, b.s2, c); \
1709 c = fma(a.s3, b.s3, c); \
1710 c = fma(a.s4, b.s4, c); \
1711 c = fma(a.s5, b.s5, c); \
1712 c = fma(a.s6, b.s6, c); \
1713 c = fma(a.s7, b.s7, c); \
1714 c = fma(a.s8, b.s8, c); \
1715 c = fma(a.s9, b.s9, c); \
1716 c = fma(a.sA, b.sA, c); \
1717 c = fma(a.sB, b.sB, c); \
1718 c = fma(a.sC, b.sC, c); \
1719 c = fma(a.sD, b.sD, c); \
1720 c = fma(a.sE, b.sE, c); \
1721 c = fma(a.sF, b.sF, c); \
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001722 })
1723#else // K0 not supported
1724#error "K0 value not supported"
1725#endif // K0 conditions
1726
1727#if N0 == 2
1728#define ARM_DOT_K0XN0(a, b, c) \
1729 ({ \
1730 ARM_DOT_K0((a), (b##0), (c.s0)); \
1731 ARM_DOT_K0((a), (b##1), (c.s1)); \
1732 })
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001733#elif N0 == 3 // N0 == 3
1734#define ARM_DOT_K0XN0(a, b, c) \
1735 ({ \
1736 ARM_DOT_K0((a), (b##0), (c.s0)); \
1737 ARM_DOT_K0((a), (b##1), (c.s1)); \
1738 ARM_DOT_K0((a), (b##2), (c.s2)); \
1739 })
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001740#elif N0 == 4 // N0 == 4
1741#define ARM_DOT_K0XN0(a, b, c) \
1742 ({ \
1743 ARM_DOT_K0((a), (b##0), (c.s0)); \
1744 ARM_DOT_K0((a), (b##1), (c.s1)); \
1745 ARM_DOT_K0((a), (b##2), (c.s2)); \
1746 ARM_DOT_K0((a), (b##3), (c.s3)); \
1747 })
1748#elif N0 == 8 // N0 == 8
1749#define ARM_DOT_K0XN0(a, b, c) \
1750 ({ \
1751 ARM_DOT_K0((a), (b##0), (c.s0)); \
1752 ARM_DOT_K0((a), (b##1), (c.s1)); \
1753 ARM_DOT_K0((a), (b##2), (c.s2)); \
1754 ARM_DOT_K0((a), (b##3), (c.s3)); \
1755 ARM_DOT_K0((a), (b##4), (c.s4)); \
1756 ARM_DOT_K0((a), (b##5), (c.s5)); \
1757 ARM_DOT_K0((a), (b##6), (c.s6)); \
1758 ARM_DOT_K0((a), (b##7), (c.s7)); \
1759 })
1760#elif N0 == 16 // N0 == 16
1761#define ARM_DOT_K0XN0(a, b, c) \
1762 ({ \
1763 ARM_DOT_K0((a), (b##0), (c.s0)); \
1764 ARM_DOT_K0((a), (b##1), (c.s1)); \
1765 ARM_DOT_K0((a), (b##2), (c.s2)); \
1766 ARM_DOT_K0((a), (b##3), (c.s3)); \
1767 ARM_DOT_K0((a), (b##4), (c.s4)); \
1768 ARM_DOT_K0((a), (b##5), (c.s5)); \
1769 ARM_DOT_K0((a), (b##6), (c.s6)); \
1770 ARM_DOT_K0((a), (b##7), (c.s7)); \
1771 ARM_DOT_K0((a), (b##8), (c.s8)); \
1772 ARM_DOT_K0((a), (b##9), (c.s9)); \
1773 ARM_DOT_K0((a), (b##A), (c.sA)); \
1774 ARM_DOT_K0((a), (b##B), (c.sB)); \
1775 ARM_DOT_K0((a), (b##C), (c.sC)); \
1776 ARM_DOT_K0((a), (b##D), (c.sD)); \
1777 ARM_DOT_K0((a), (b##E), (c.sE)); \
1778 ARM_DOT_K0((a), (b##F), (c.sF)); \
1779 })
1780#else // N0 not supported
1781#error "N0 value not supported"
1782#endif // N0 conditions
1783
1784/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1785 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1786 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1787 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001788 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
1789 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001790 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
1791 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (i.e. -DV0=2)
1792 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (i.e. -DH0=2)
1793 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1794 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1795 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001796 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001797 * - N0 = 2, 3, 4, 8, 16
1798 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001799 * - V0 >= 1
1800 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001801 *
1802 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1803 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1804 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1805 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1806 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1807 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001808 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1809 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1810 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1811 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1812 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1813 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1814 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1815 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1816 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1817 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1818 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1819 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1820 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1821 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1822 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1823 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1824 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1825 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1826 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1827 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1828 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1829 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1830 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1831 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1832 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1833 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1834 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1835 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1836 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1837 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001838 */
1839__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1840 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001841#if defined(BETA)
1842 IMAGE_DECLARATION(bias),
1843#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001844 IMAGE_DECLARATION(dst),
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001845 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001846 uint lhs_stride_z,
1847 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001848#if defined(BETA)
1849 uint bias_stride_z,
1850#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001851 uint dst_stride_z
1852#if defined(REINTERPRET_OUTPUT_AS_3D)
1853 ,
1854 uint dst_cross_plane_pad
1855#endif // REINTERPRET_OUTPUT_AS_3D
1856 )
1857{
1858 // Block size
1859#define LHS_BLOCK_SIZE ((K0) * (M0))
1860
1861#if defined(LHS_INTERLEAVE)
1862#define LHS_OFFSET_X (K0)
1863#define LHS_STEP_X ((K0) * (V0))
1864#define LHS_STEP_LOOP (1)
1865#else // defined(INTERLEAVE)
1866#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1867#define LHS_STEP_X (K0)
1868#define LHS_STEP_LOOP (V0)
1869#endif // defined(INTERLEAVE)
1870
1871 // Block size
1872#define RHS_BLOCK_SIZE ((K0) * (N0))
1873
1874 // RHS offset and step X
1875#if defined(RHS_INTERLEAVE)
1876#define RHS_OFFSET_X (K0)
1877#define RHS_STEP_X ((K0) * (H0))
1878#define RHS_STEP_LOOP (1)
1879#else // defined(RHS_INTERLEAVE)
1880#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1881#define RHS_STEP_X (K0)
1882#define RHS_STEP_LOOP (H0)
1883#endif // defined(RHS_INTERLEAVE)
1884
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001885#if defined(DUMMY_WORK_ITEMS)
1886 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1887 {
1888 return;
1889 }
1890#endif // defined(DUMMY_WORK_ITEMS)
1891
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001892 // Compute LHS matrix address
1893 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1894 (get_global_id(2) * lhs_stride_z);
1895
1896 // Compute RHS matrix address
1897 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1898
1899#if defined(MATRIX_B_DEPTH)
1900 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1901 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1902#else // defined(MATRIX_B_DEPTH)
1903 rhs_addr += get_global_id(2) * rhs_stride_z;
1904#endif // defined(MATRIX_B_DEPTH)
1905
1906 // Initialize the accumulators
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +00001907 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001908
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001909 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1910 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001911
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001912 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001913 {
1914 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001915 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1916 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1917 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1918 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1919 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1920 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1921 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1922 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001923 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001924 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001925
1926 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001927 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001928
1929 // Accumulate
1930 ARM_DOT_K0XN0(a0, b, c0);
1931#if M0 > 1
1932 ARM_DOT_K0XN0(a1, b, c1);
1933#endif // M0 > 1
1934#if M0 > 2
1935 ARM_DOT_K0XN0(a2, b, c2);
1936#endif // M0 > 2
1937#if M0 > 3
1938 ARM_DOT_K0XN0(a3, b, c3);
1939#endif // M0 > 3
1940#if M0 > 4
1941 ARM_DOT_K0XN0(a4, b, c4);
1942#endif // M0 > 4
1943#if M0 > 5
1944 ARM_DOT_K0XN0(a5, b, c5);
1945#endif // M0 > 5
1946#if M0 > 6
1947 ARM_DOT_K0XN0(a6, b, c6);
1948#endif // M0 > 6
1949#if M0 > 7
1950 ARM_DOT_K0XN0(a7, b, c7);
1951#endif // M0 > 7
1952
1953 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
1954 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1955 }
1956
1957 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
1958
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001959 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001960
1961#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001962
1963 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001964 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001965 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1966 // multiply dst_stride_z by DEPTH_GEMM3D
1967 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
1968
1969#else // defined(REINTERPRET_OUTPUT_AS_3D)
1970
1971 // Add offset for batched GEMM
1972 dst_addr += get_global_id(2) * dst_stride_z;
1973
1974#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1975
1976 // Multiply by the weight of matrix-matrix product and store the result
1977#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001978 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001979#endif // defined(ALPHA)
1980
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001981 // Add beta*bias
1982#if defined(BETA)
1983#if defined(BROADCAST_BIAS)
1984 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1985
1986 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1987
1988#ifndef UNIT_BETA
1989 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1990#endif // UNIT_BIAS
1991
1992 // c = c + bias[broadcasted]
1993 ADD_BLOCK_BROADCAST(M0, c, bias0);
1994
1995#else // defined(BROADCAST_BIAS)
1996 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1997 2) * bias_stride_z;
1998
1999 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2000
2001#ifndef UNIT_BETA
2002 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2003#endif // UNIT_BIAS
2004
2005 // c = c + bias
2006 ADD_BLOCK(M0, c, bias);
2007
2008#endif // defined(BROADCAST_BIAS)
2009#endif // defined(BETA)
2010
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002011 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01002012 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002013
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002014#undef LHS_BLOCK_SIZE
2015#undef LHS_OFFSET_X
2016#undef LHS_STEP_X
2017#undef RHS_BLOCK_SIZE
2018#undef RHS_OFFSET_X
2019#undef RHS_STEP_X
2020}
giuros01b3204e72019-04-01 13:50:22 +01002021
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002022#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2023
giuros01b3204e72019-04-01 13:50:22 +01002024#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2025
// VFMA(a, b, c): fused multiply-add accumulated in place, c = fma(a, b, c).
// Written as a statement expression, so the expansion also yields the updated value of c.
#define VFMA(a, b, c) ({ (c) = fma((a), (b), (c)); })
2030
2031#if M0 == 1
2032#define RHS_VFMA_M0xN0(i, a, b, c) \
2033 ({ \
2034 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2035 })
2036#elif M0 == 2 // M0 == 2
2037#define RHS_VFMA_M0xN0(i, a, b, c) \
2038 ({ \
2039 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2040 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2041 })
2042#elif M0 == 3 // M0 == 3
2043#define RHS_VFMA_M0xN0(i, a, b, c) \
2044 ({ \
2045 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2046 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2047 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2048 })
2049#elif M0 == 4 // M0 == 4
2050#define RHS_VFMA_M0xN0(i, a, b, c) \
2051 ({ \
2052 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2053 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2054 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2055 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2056 })
2057#elif M0 == 5 // M0 == 5
2058#define RHS_VFMA_M0xN0(i, a, b, c) \
2059 ({ \
2060 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2061 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2062 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2063 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2064 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2065 })
2066#elif M0 == 6 // M0 == 6
2067#define RHS_VFMA_M0xN0(i, a, b, c) \
2068 ({ \
2069 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2070 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2071 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2072 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2073 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2074 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2075 })
2076#elif M0 == 7 // M0 == 7
2077#define RHS_VFMA_M0xN0(i, a, b, c) \
2078 ({ \
2079 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2080 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2081 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2082 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2083 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2084 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2085 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2086 })
2087#elif M0 == 8 // M0 == 8
2088#define RHS_VFMA_M0xN0(i, a, b, c) \
2089 ({ \
2090 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
2091 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
2092 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
2093 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
2094 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
2095 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
2096 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
2097 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
2098 })
2099#else // M0 not supported
2100#error "M0 not supported"
2101#endif // M0 not supported
2102
2103/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2104 * The LHS matrix is NOT reshaped
2105 * The RHS matrix is NOT reshaped
2106 *
2107 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2108 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (i.e. -DM=52, -DN=30 and -DK=90)
2109 * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
2110 * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
2111 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (i.e., -DK0=2)
2112 * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
2113 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2114 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
2115 * - N0 = 2, 3, 4, 8, 16
2116 * - K0 = 2, 3, 4, 8, 16
2117 *
2118 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2119 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2120 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2121 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2122 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2123 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
2124 *
2125 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2126 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2127 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2128 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2129 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2130 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2131 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2132 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2133 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2134 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2135 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2136 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2137 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2138 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2139 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2140 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2141 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2142 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2143 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2144 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2145 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2146 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
2147 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2148 */
2149__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
2150 IMAGE_DECLARATION(rhs),
2151 IMAGE_DECLARATION(dst),
2152 uint lhs_stride_z,
2153 uint rhs_stride_z,
2154 uint dst_stride_z
2155#if defined(REINTERPRET_INPUT_AS_3D)
2156 ,
2157 uint lhs_cross_plane_pad
2158#endif // REINTERPRET_INPUT_AS_3D
2159#if defined(REINTERPRET_OUTPUT_AS_3D)
2160 ,
2161 uint dst_cross_plane_pad
2162#endif // REINTERPRET_OUTPUT_AS_3D
2163 )
2164{
2165 // Block size
2166#define RHS_BLOCK_SIZE ((K0) * (N0))
2167
2168 // RHS offset and step X
2169#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2170
2171 uint x = get_global_id(0);
2172 uint y = get_global_id(1);
2173 uint z = get_global_id(2);
2174
2175#if defined(DUMMY_WORK_ITEMS)
2176 if((x * N0 >= N) || (y * M0 >= M))
2177 {
2178 return;
2179 }
2180#endif // defined(DUMMY_WORK_ITEMS)
2181
2182 // Compute LHS matrix address
2183 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
2184
2185 // Compute RHS matrix address
2186 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
2187
2188#if defined(MATRIX_B_DEPTH)
2189 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2190 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2191#else // defined(MATRIX_B_DEPTH)
2192 rhs_offset += z * rhs_stride_z;
2193#endif // defined(MATRIX_B_DEPTH)
2194
2195 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
2196 REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
2197
2198#if defined(REINTERPRET_INPUT_AS_3D)
2199 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2200 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
2201
2202 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2203 // multiply lhs_stride_z by DEPTH_GEMM3D
2204 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
2205
2206#else // defined(REINTERPRET_INPUT_AS_3D)
2207
2208 // Add offset for batched GEMM
2209 lhs_offset += z * lhs_stride_z;
2210
2211#endif // defined(REINTERPRET_INPUT_AS_3D)
2212
2213 // Initialize the accumulators
2214 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
2215
2216 int i = 0;
2217 for(; i <= (K - K0); i += K0)
2218 {
2219 // Supported cases (M0, K0):
2220 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
2221 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
2222 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
2223 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
2224 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
2225 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
2226 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
2227 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
2228 // Load values from LHS matrix
2229 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
2230
2231 // Load values from RHS matrix
2232 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
2233
2234 RHS_VFMA_M0xN0(0, a, b0, c);
2235 RHS_VFMA_M0xN0(1, a, b1, c);
2236#if K0 > 2
2237 RHS_VFMA_M0xN0(2, a, b2, c);
2238#endif // K0 > 2
2239#if K0 > 3
2240 RHS_VFMA_M0xN0(3, a, b3, c);
2241#endif // K0 > 3
2242#if K0 > 4
2243 RHS_VFMA_M0xN0(4, a, b4, c);
2244 RHS_VFMA_M0xN0(5, a, b5, c);
2245 RHS_VFMA_M0xN0(6, a, b6, c);
2246 RHS_VFMA_M0xN0(7, a, b7, c);
2247#endif // K0 > 4
2248#if K0 > 8
2249 RHS_VFMA_M0xN0(8, a, b8, c);
2250 RHS_VFMA_M0xN0(9, a, b9, c);
2251 RHS_VFMA_M0xN0(A, a, b10, c);
2252 RHS_VFMA_M0xN0(B, a, b11, c);
2253 RHS_VFMA_M0xN0(C, a, b12, c);
2254 RHS_VFMA_M0xN0(D, a, b13, c);
2255 RHS_VFMA_M0xN0(E, a, b14, c);
2256 RHS_VFMA_M0xN0(F, a, b15, c);
2257#endif // K0 > 8
2258
2259 lhs_offset += K0 * sizeof(DATA_TYPE);
2260 rhs_offset += K0 * rhs_stride_y;
2261 }
2262
2263 // Left-over accumulations
2264 for(; i < K; ++i)
2265 {
2266 // Load values from LHS matrix
2267 VEC_DATA_TYPE(DATA_TYPE, 2)
2268 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
2269#if M0 > 1
2270 VEC_DATA_TYPE(DATA_TYPE, 2)
2271 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
2272#endif // M0 > 1
2273#if M0 > 2
2274 VEC_DATA_TYPE(DATA_TYPE, 2)
2275 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
2276#endif // M0 > 2
2277#if M0 > 3
2278 VEC_DATA_TYPE(DATA_TYPE, 2)
2279 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
2280#endif // M0 > 3
2281#if M0 > 4
2282 VEC_DATA_TYPE(DATA_TYPE, 2)
2283 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
2284#endif // M0 > 4
2285#if M0 > 5
2286 VEC_DATA_TYPE(DATA_TYPE, 2)
2287 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
2288#endif // M0 > 5
2289#if M0 > 6
2290 VEC_DATA_TYPE(DATA_TYPE, 2)
2291 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
2292#endif // M0 > 6
2293#if M0 > 7
2294 VEC_DATA_TYPE(DATA_TYPE, 2)
2295 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
2296#endif // M0 > 7
2297
2298 VEC_DATA_TYPE(DATA_TYPE, N0)
2299 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
2300 RHS_VFMA_M0xN0(0, a, b, c);
2301
2302 lhs_offset += sizeof(DATA_TYPE);
2303 rhs_offset += rhs_stride_y;
2304 }
2305
2306 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2307
2308 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
2309
2310#if defined(REINTERPRET_OUTPUT_AS_3D)
2311 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2312 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2313
2314 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2315 // multiply dst_stride_z by DEPTH_GEMM3D
2316 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2317
2318#else // defined(REINTERPRET_OUTPUT_AS_3D)
2319
2320 // Add offset for batched GEMM
2321 dst_addr += z * dst_stride_z;
2322
2323#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2324
2325 // Multiply by the weight of matrix-matrix product and store the result
2326 // Multiply by the weight of matrix-matrix product and store the result
2327#if defined(ALPHA)
2328 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2329#endif // defined(ALPHA)
2330
2331 // Store output block
2332 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2333
2334#undef RHS_BLOCK_SIZE
2335#undef RHS_OFFSET_X
2336#undef RHS_STEP_X
2337}
2338#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2339
Gian Marco36a0a462018-01-12 10:21:40 +00002340#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002341/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002342 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002343 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002344 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2345 *
Gian Marco19835e52018-01-30 13:35:54 +00002346 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2347 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2348 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002349 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2350 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002351 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002352 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2353 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2354 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2355 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2356 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2357 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002358 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2359 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002360 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2361 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2362 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2363 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2364 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2365 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002366 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002367 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2368 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2369 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2370 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2371 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002372 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2373 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2374 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2375 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002376 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002377 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002378 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002379 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002380 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002381 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002382 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2383 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2384 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002385 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002386 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002387__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
2388 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002389#if defined(ADD_VEC_C)
2390 VECTOR_DECLARATION(src2),
2391#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01002392 IMAGE_DECLARATION(dst),
2393 uint src0_stride_z,
2394 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002395 uint dst_stride_z
2396#if defined(REINTERPRET_OUTPUT_AS_3D)
2397 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002398 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002399#endif // REINTERPRET_OUTPUT_AS_3D
2400 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002401{
Gian Marco36a0a462018-01-12 10:21:40 +00002402 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2403 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002404 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002405
Gian Marco36a0a462018-01-12 10:21:40 +00002406 // Offset
2407 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2408 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002409
Gian Marco36a0a462018-01-12 10:21:40 +00002410 // src_addr_a = address of matrix A
2411 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002412 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2413 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2414
2415#if defined(MATRIX_B_DEPTH)
2416 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2417 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2418#else // defined(MATRIX_B_DEPTH)
2419 src1_addr_in_bytes += z * src1_stride_z;
2420#endif // defined(MATRIX_B_DEPTH)
2421
2422 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
2423 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002424
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002425 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002426 __global float *src_end_addr_b = src_addr_b + COLS_B;
2427
2428 src_addr_a += offset_row_a;
2429 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002430
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002431 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002432 float4 c00 = 0.0f;
2433 float4 c10 = 0.0f;
2434 float4 c20 = 0.0f;
2435 float4 c30 = 0.0f;
2436
Gian Marco36a0a462018-01-12 10:21:40 +00002437 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002438 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002439 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002440 float4 a0 = vload4(0, src_addr_a);
2441 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002442
2443 c00 += (float4)a0.s0 * b0;
2444 c10 += (float4)a0.s1 * b0;
2445 c20 += (float4)a0.s2 * b0;
2446 c30 += (float4)a0.s3 * b0;
2447
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002448 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002449 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2450 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002451
2452 c00 += (float4)a0.s0 * b0;
2453 c10 += (float4)a0.s1 * b0;
2454 c20 += (float4)a0.s2 * b0;
2455 c30 += (float4)a0.s3 * b0;
2456 }
2457
Gian Marco36a0a462018-01-12 10:21:40 +00002458 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002459 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002460 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002461 float4 a0 = vload4(0, src_addr_a);
2462 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002463
2464 c00 += (float4)a0.s0 * b0;
2465 c10 += (float4)a0.s1 * b0;
2466 c20 += (float4)a0.s2 * b0;
2467 c30 += (float4)a0.s3 * b0;
2468 }
2469
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002470 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002471 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2472
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002473#if defined(ALPHA)
2474 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002475 c00 = c00 * (float4)ALPHA;
2476 c10 = c10 * (float4)ALPHA;
2477 c20 = c20 * (float4)ALPHA;
2478 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002479#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002480
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002481#if defined(ADD_VEC_C)
2482 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
2483 float4 c0 = vload4(0, src2_addr);
2484
2485 c00 += c0;
2486 c10 += c0;
2487 c20 += c0;
2488 c30 += c0;
2489#endif /* defined(ADD_VEC_C) */
2490
Gian Marcoae2af742018-02-15 12:35:44 +00002491 // Compute dst address
2492 __global uchar *dst_addr = offset(&dst, 0, 0);
2493
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002494#if defined(REINTERPRET_OUTPUT_AS_3D)
2495 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002496 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002497 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002498 // | |
2499 // | plane0 |
2500 // | |
2501 // |__________________|
2502 // |******************|
2503 // | cross_plane_pad |
2504 // |******************|
2505 // | |
2506 // | plane1 |
2507 // | |
2508 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002509
2510 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
2511 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
2512 zout = min(DEPTH_GEMM3D - 1, zout);
2513
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002514 // Add offset due to the cross plane paddings
2515 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002516
2517 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2518 // multiply dst_stride_z by DEPTH_GEMM3D
2519 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2520
2521 // Store 4x4 block
2522 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2523 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2524 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2525 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
2526
2527#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00002528 // Add offset for batched GEMM
2529 dst_addr += z * dst_stride_z;
2530
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002531 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +00002532 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2533 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2534 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2535 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002536#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002537}
2538
/** Bifrost-optimised F32 kernel computing the matrix multiplication between matrix A (src0) and matrix B (src1), one 4x4 output tile per work-item.
 *
 * Matrix A must be reshaped with @ref gemm_interleave4x4_32bit and matrix B with @ref gemm_transpose1x4 before running this kernel.
 * For every k-step the work-item loads 4 values of A and 4 values of B and updates its 4x4 tile with scalar fma. Then it optionally
 * scales the tile by ALPHA, optionally adds a 4-element vector (src2) to every row of the tile and finally stores the tile as 4 float4 rows.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16).
 *       The batch index used for matrix B is then wrapped modulo MATRIX_B_DEPTH so that B is not read out of bounds.
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) with multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of a convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *       Each output row r is written into plane min(r / HEIGHT_GEMM3D, DEPTH_GEMM3D - 1), so (HEIGHT_GEMM3D * DEPTH_GEMM3D) should cover the rows of matrix A NOT reshaped
 *
 * @note In case a vector (src2) has to be added to the result, -DADD_VEC_C must be passed at compile time. The vector is added after the ALPHA scaling.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of rows, i.e. planes are separated by cross_plane_pad * dst_stride_y bytes (only if REINTERPRET_OUTPUT_AS_3D is defined)
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Row of reshaped B (block_x), row of reshaped A (block_y) and batch handled by this work-item
    const int block_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int block_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch   = get_global_id(2);

    // Position (in floats) of this work-item's 4-element sub-block inside the interleaved A row and the transposed B row
    const int sub_block_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int sub_block_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // Byte offsets of the rows of A and B read by this work-item
    int a_offset_in_bytes = src0_offset_first_element_in_bytes + block_y * src0_stride_y + batch * src0_stride_z;
    int b_offset_in_bytes = src1_offset_first_element_in_bytes + block_x * src1_stride_y;
#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more: wrap the batch index so that B is not read out of bounds
    b_offset_in_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_offset_in_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *a_ptr = (__global float *)(src0_ptr + a_offset_in_bytes) + sub_block_a;
    __global float *b_ptr = (__global float *)(src1_ptr + b_offset_in_bytes) + sub_block_b;

    // 4x4 output tile, one scalar accumulator per element (accRC = row R, column C)
    float acc00 = 0.0f, acc01 = 0.0f, acc02 = 0.0f, acc03 = 0.0f;
    float acc10 = 0.0f, acc11 = 0.0f, acc12 = 0.0f, acc13 = 0.0f;
    float acc20 = 0.0f, acc21 = 0.0f, acc22 = 0.0f, acc23 = 0.0f;
    float acc30 = 0.0f, acc31 = 0.0f, acc32 = 0.0f, acc33 = 0.0f;

// Number of k-steps: the B pointer advances by 4 * MULT_TRANSPOSE1XW_WIDTH floats per step
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

// One k-step: load 4 values of A and 4 values of B, advance both pointers, then update the 4x4 tile with scalar fma
#define GEMM_F32_BIFROST_STEP()                 \
    do                                          \
    {                                           \
        const float4 va = vload4(0, a_ptr);     \
        const float4 vb = vload4(0, b_ptr);     \
        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT; \
        b_ptr += 4 * MULT_TRANSPOSE1XW_WIDTH;   \
        acc00 = fma(va.s0, vb.s0, acc00);       \
        acc01 = fma(va.s0, vb.s1, acc01);       \
        acc02 = fma(va.s0, vb.s2, acc02);       \
        acc03 = fma(va.s0, vb.s3, acc03);       \
        acc10 = fma(va.s1, vb.s0, acc10);       \
        acc11 = fma(va.s1, vb.s1, acc11);       \
        acc12 = fma(va.s1, vb.s2, acc12);       \
        acc13 = fma(va.s1, vb.s3, acc13);       \
        acc20 = fma(va.s2, vb.s0, acc20);       \
        acc21 = fma(va.s2, vb.s1, acc21);       \
        acc22 = fma(va.s2, vb.s2, acc22);       \
        acc23 = fma(va.s2, vb.s3, acc23);       \
        acc30 = fma(va.s3, vb.s0, acc30);       \
        acc31 = fma(va.s3, vb.s1, acc31);       \
        acc32 = fma(va.s3, vb.s2, acc32);       \
        acc33 = fma(va.s3, vb.s3, acc33);       \
    } while(0)

    int k = 0;

    // Main loop, manually unrolled by a factor of 4
    for(; k <= (int)(COLS_MTX_B - 4); k += 4)
    {
        GEMM_F32_BIFROST_STEP();
        GEMM_F32_BIFROST_STEP();
        GEMM_F32_BIFROST_STEP();
        GEMM_F32_BIFROST_STEP();
    }

    // Leftover k-steps when COLS_MTX_B is not a multiple of 4
    for(; k < (int)(COLS_MTX_B); ++k)
    {
        GEMM_F32_BIFROST_STEP();
    }

#undef GEMM_F32_BIFROST_STEP

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Scale the matrix product by alpha
    acc00 = acc00 * ALPHA;
    acc01 = acc01 * ALPHA;
    acc02 = acc02 * ALPHA;
    acc03 = acc03 * ALPHA;

    acc10 = acc10 * ALPHA;
    acc11 = acc11 * ALPHA;
    acc12 = acc12 * ALPHA;
    acc13 = acc13 * ALPHA;

    acc20 = acc20 * ALPHA;
    acc21 = acc21 * ALPHA;
    acc22 = acc22 * ALPHA;
    acc23 = acc23 * ALPHA;

    acc30 = acc30 * ALPHA;
    acc31 = acc31 * ALPHA;
    acc32 = acc32 * ALPHA;
    acc33 = acc33 * ALPHA;
#endif // defined(ALPHA)

    // Base address of this work-item's output tile
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(ADD_VEC_C)
    // Add the same 4-element vector to each row of the (already scaled) tile
    const float4 vec_c = vload4(0, (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x));

    acc00 += vec_c.s0;
    acc01 += vec_c.s1;
    acc02 += vec_c.s2;
    acc03 += vec_c.s3;

    acc10 += vec_c.s0;
    acc11 += vec_c.s1;
    acc12 += vec_c.s2;
    acc13 += vec_c.s3;

    acc20 += vec_c.s0;
    acc21 += vec_c.s1;
    acc22 += vec_c.s2;
    acc23 += vec_c.s3;

    acc30 += vec_c.s0;
    acc31 += vec_c.s1;
    acc32 += vec_c.s2;
    acc33 += vec_c.s3;
#endif /* defined(ADD_VEC_C) */

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is stored into a 3D tensor whose planes may be separated by cross plane paddings:
    //
    // |                  |
    // |      plane0      |
    // |                  |
    // |__________________|
    // |******************|
    // |  cross_plane_pad |
    // |******************|
    // |                  |
    // |      plane1      |
    // |                  |
    // |__________________|
    //
    // For each of the 4 tile rows, find its plane (row index M = get_global_id(1) * 4 + r, divided by HEIGHT_GEMM3D
    // and clamped to the last plane) and turn it into the byte offset added by the paddings of the preceding planes
    uint4 pad_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    pad_offset       = min(DEPTH_GEMM3D - 1, pad_offset);
    pad_offset *= (cross_plane_pad * dst_stride_y);

    // Batches live in the fourth dimension, hence dst_stride_z is multiplied by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Plain 2D output: no padding between the rows of the tile
    const uint4 pad_offset = (uint4)0;

    // Batched GEMM: one output matrix per z slice
    dst_addr += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store the 4x4 tile, one float4 per output row
    vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + pad_offset.s0));
    vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + pad_offset.s1));
    vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + pad_offset.s2));
    vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + pad_offset.s3));
}

// Undefine local defines
#undef COLS_MTX_B
2884
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01002885#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002886/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002887 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002888 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002889 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
2890 *
Gian Marco19835e52018-01-30 13:35:54 +00002891 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
2892 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
2893 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002894 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2895 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002896 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002897 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
2898 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2899 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2900 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2901 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2902 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002903 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
2904 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002905 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2906 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2907 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2908 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2909 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2910 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002911 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002912 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2913 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2914 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2915 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2916 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002917 * @param[in] src2_ptr (Optional) Pointer to the source matrix. Supported data types: same as @p src0_ptr
2918 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
2919 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
2920 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01002921 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002922 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002923 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002924 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00002925 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002926 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002927 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2928 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2929 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002930 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002931 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01002932__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
2933 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002934#if defined(ADD_VEC_C)
2935 VECTOR_DECLARATION(src2),
2936#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00002937 IMAGE_DECLARATION(dst),
2938 uint src0_stride_z,
2939 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002940 uint dst_stride_z
2941#if defined(REINTERPRET_OUTPUT_AS_3D)
2942 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002943 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002944#endif // REINTERPRET_OUTPUT_AS_3D
2945 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002946{
Gian Marco36a0a462018-01-12 10:21:40 +00002947 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
2948 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00002949 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002950
Gian Marco36a0a462018-01-12 10:21:40 +00002951 // Offset
2952 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
2953 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002954
Gian Marco36a0a462018-01-12 10:21:40 +00002955 // src_addr_a = address of matrix A
2956 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002957 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
2958 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
2959
2960#if defined(MATRIX_B_DEPTH)
2961 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2962 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
2963#else // defined(MATRIX_B_DEPTH)
2964 src1_addr_in_bytes += z * src1_stride_z;
2965#endif // defined(MATRIX_B_DEPTH)
2966
2967 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
2968 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002969
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002970 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00002971 __global half *src_end_addr_b = src_addr_b + COLS_B;
2972
2973 src_addr_a += offset_row_a;
2974 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002975
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002976 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002977 half8 c00 = 0.0f;
2978 half8 c10 = 0.0f;
2979 half8 c20 = 0.0f;
2980 half8 c30 = 0.0f;
2981
Gian Marco36a0a462018-01-12 10:21:40 +00002982 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002983 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002984 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002985 half4 a0 = vload4(0, src_addr_a);
2986 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002987
2988 c00 += (half8)a0.s0 * b0;
2989 c10 += (half8)a0.s1 * b0;
2990 c20 += (half8)a0.s2 * b0;
2991 c30 += (half8)a0.s3 * b0;
2992
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002993 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00002994 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
2995 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002996
2997 c00 += (half8)a0.s0 * b0;
2998 c10 += (half8)a0.s1 * b0;
2999 c20 += (half8)a0.s2 * b0;
3000 c30 += (half8)a0.s3 * b0;
3001 }
3002
Gian Marco36a0a462018-01-12 10:21:40 +00003003 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003004 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003005 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003006 half4 a0 = vload4(0, src_addr_a);
3007 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003008
3009 c00 += (half8)a0.s0 * b0;
3010 c10 += (half8)a0.s1 * b0;
3011 c20 += (half8)a0.s2 * b0;
3012 c30 += (half8)a0.s3 * b0;
3013 }
3014
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003015 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003016 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3017
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003018#if defined(ALPHA)
3019 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003020 c00 = c00 * (half8)ALPHA;
3021 c10 = c10 * (half8)ALPHA;
3022 c20 = c20 * (half8)ALPHA;
3023 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003024#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003025
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003026#if defined(ADD_VEC_C)
3027 // *INDENT-OFF*
3028 // clang-format off
3029 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3030 half8 c0 = vload8(0, src2_addr);
3031 // clang-format on
3032 // *INDENT-ON*
3033
3034 c00 += c0;
3035 c10 += c0;
3036 c20 += c0;
3037 c30 += c0;
3038#endif /* defined(ADD_VEC_C) */
3039
Gian Marcoae2af742018-02-15 12:35:44 +00003040 // Compute dst address
3041 __global uchar *dst_addr = offset(&dst, 0, 0);
3042
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003043#if defined(REINTERPRET_OUTPUT_AS_3D)
3044 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003045 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003046 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003047 // | |
3048 // | plane0 |
3049 // | |
3050 // |__________________|
3051 // |******************|
3052 // | cross_plane_pad |
3053 // |******************|
3054 // | |
3055 // | plane1 |
3056 // | |
3057 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003058
3059 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
3060 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3061 zout = min(DEPTH_GEMM3D - 1, zout);
3062
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003063 // Add offset due to the cross plane paddings
3064 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003065
3066 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3067 // multiply dst_stride_z by DEPTH_GEMM3D
3068 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3069
3070 // Store 4x8 block
3071 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3072 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3073 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3074 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
3075
3076#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003077 // Add offset for batched GEMM
3078 dst_addr += z * dst_stride_z;
3079
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003080 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +00003081 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
3082 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
3083 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
3084 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003085#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003086}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003087
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in 32-bit floating point variables.
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication.
 *
 * Inputs and output are F16: the loaded values are converted to float, the products are accumulated in float8 accumulators,
 * the optional ALPHA scaling and bias addition are applied in float, and the result is converted back to half only when stored.
 * Each work-item computes a 4x8 block of the output.
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. The same 8 values of src2 are added
 * to each of the 4 output rows computed by the work-item, after the ALPHA scaling.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) is expected to equal the number of rows (M) of matrix A NOT reshaped, i.e. the number of output rows.
 *          Output rows are assigned to planes as row / HEIGHT_GEMM3D, clamped to DEPTH_GEMM3D - 1.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Index of the reshaped row of matrix B (x) / matrix A (y) this work-item reads, and the batch index (z)
    const int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset of this work-item's sub-block inside the interleaved (A) / transposed (B) reshaped row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (taken before the per-work-item offset is applied)
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (F32 accumulation for better precision than the F16 kernels)
    float8 c00 = 0.0f;
    float8 c10 = 0.0f;
    float8 c20 = 0.0f;
    float8 c30 = 0.0f;

    // Main loop: two k-steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one k-step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c00 += (float8)a0.s0 * b0;
        c10 += (float8)a0.s1 * b0;
        c20 += (float8)a0.s2 * b0;
        c30 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (float8)ALPHA;
    c10 = c10 * (float8)ALPHA;
    c20 = c20 * (float8)ALPHA;
    c30 = c30 * (float8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8 c0 = convert_float8(vload8(0, src2_addr));
    // clang-format on
    // *INDENT-ON*

    // The same bias values are added to every row of the 4x8 block
    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block (converted back to F16)
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block (converted back to F16)
    vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}
3289
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication.
 *
 * Accumulation is performed in F16 with fma(). Each work-item computes a 4x8 block of the output, and the main loop is unrolled
 * to process 4 k-steps per iteration.
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time. The same 8 values of src2 are added
 * to each of the 4 output rows computed by the work-item, after the ALPHA scaling.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The number of k-steps is computed as COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH) with integer division, so COLS_B is expected to be
 *       a multiple of 8 * MULT_TRANSPOSE1XW_WIDTH; any remainder is not processed.
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) is expected to equal the number of rows (M) of matrix A NOT reshaped, i.e. the number of output rows.
 *          Output rows are assigned to planes as row / HEIGHT_GEMM3D, clamped to DEPTH_GEMM3D - 1.
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                         VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    // Index of the reshaped row of matrix B (x) / matrix A (y) this work-item reads, and the batch index (z)
    const int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int z = get_global_id(2);

    // Offset of this work-item's sub-block inside the interleaved (A) / transposed (B) reshaped row
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c00 = 0.0f;
    half8 c10 = 0.0f;
    half8 c20 = 0.0f;
    half8 c30 = 0.0f;

// Number of k-steps (integer division: a remainder of COLS_B is not processed)
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    int i = 0;
    // Main loop: 4 k-steps per iteration
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 the 4 values of A for consecutive k-steps are contiguous,
        // so a single vload8 fetches A for two k-steps (s0-s3 for the first, s4-s7 for the second)

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s4, b0, c00);
        c10 = fma((half8)a0.s5, b0, c10);
        c20 = fma((half8)a0.s6, b0, c20);
        c30 = fma((half8)a0.s7, b0, c30);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over loop: one k-step per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c00 = fma((half8)a0.s0, b0, c00);
        c10 = fma((half8)a0.s1, b0, c10);
        c20 = fma((half8)a0.s2, b0, c20);
        c30 = fma((half8)a0.s3, b0, c30);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

#if defined(ALPHA)
    // Multiply by the weight of matrix product
    c00 = c00 * (half8)ALPHA;
    c10 = c10 * (half8)ALPHA;
    c20 = c20 * (half8)ALPHA;
    c30 = c30 * (half8)ALPHA;
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    half8 c0 = vload8(0, src2_addr);
    // clang-format on
    // *INDENT-ON*

    // The same bias values are added to every row of the 4x8 block
    c00 += c0;
    c10 += c0;
    c20 += c0;
    c30 += c0;
#endif /* defined(ADD_VEC_C) */

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));

#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store 4x8 block
    vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
    vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
    vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
    vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
}

// Undefine local defines
#undef COLS_MTX_B
3573
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003574#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003575
Gian Marco36a0a462018-01-12 10:21:40 +00003576#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003577
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003578#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
3579#if defined(DATA_TYPE)
3580#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003581/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
3582 *
3583 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003584 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003585 * @note This OpenCL kernel works with floating point data types (F16/F32)
3586 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
3587 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003588 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003589 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3590 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003591 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003592 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3593 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003594 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3595 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3596 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3597 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3598 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003599 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3600 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003601 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003602 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3603 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3604 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3605 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3606 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003607 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003608 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3609 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3610 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3611 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3612 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003613 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
3614 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3615 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
3616 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003617 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003618 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3619 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3620 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3621 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3622 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003623 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3624 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3625 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003626 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3627 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003628 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003629__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
3630 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003631#if defined(ADD_VEC_C)
3632 VECTOR_DECLARATION(src2),
3633#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003634 IMAGE_DECLARATION(dst),
3635 uint src0_stride_z,
3636 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003637 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003638#if defined(REINTERPRET_INPUT_AS_3D)
3639 ,
3640 uint src_cross_plane_pad
3641#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003642#if defined(REINTERPRET_OUTPUT_AS_3D)
3643 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003644 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003645#endif // REINTERPRET_OUTPUT_AS_3D
3646 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003647{
// First column of matrix B (and of the output) handled by this work-item: each work-item
// covers NUM_ELEMS_PROCESSED_PER_THREAD_X consecutive columns.
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003648 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003649
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003650 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003651 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003652
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003653 // Update address for the matrix A
3654 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003655
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003656 // Update address for the matrix B
3657 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003658
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003659#if defined(REINTERPRET_INPUT_AS_3D)
3660 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3661 // in order to take into account the presence of possible cross plane paddings
3662 //
3663 // | |
3664 // | plane0 |
3665 // | |
3666 // |__________________|
3667 // |******************|
3668 // | cross_plane_pad |
3669 // |******************|
3670 // | |
3671 // | plane1 |
3672 // | |
3673 // |__________________|
3674
3675 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3676 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
// Clamp each lane's plane index to the last valid plane (DEPTH_GEMM3D - 1)
3677 zin = min(DEPTH_GEMM3D - 1, zin);
3678
3679 // Add offset due to the cross plane paddings
// From here on zin holds, per row of the tile, a byte offset (plane index * padding rows * row stride)
3680 zin *= (src_cross_plane_pad * src0_stride_y);
3681
3682 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3683 // multiply src0_stride_z by DEPTH_GEMM3D
3684 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3685
3686#else // defined(REINTERPRET_INPUT_AS_3D)
3687
Gian Marcoae2af742018-02-15 12:35:44 +00003688 // Add offset for batched GEMM
3689 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003690
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003691#endif // defined(REINTERPRET_INPUT_AS_3D)
3692
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003693#if defined(MATRIX_B_DEPTH)
3694 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3695 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3696#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003697 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003698#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00003699
// Byte offset just past the last of the COLS_A elements of the current row of matrix A;
// both loops below advance src_addr.s0 until it reaches this bound.
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003700 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
3701
3702 VECTOR_TYPE acc0 = 0.0f;
3703#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3704 VECTOR_TYPE acc1 = 0.0f;
3705#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3706#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3707 VECTOR_TYPE acc2 = 0.0f;
3708#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3709#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3710 VECTOR_TYPE acc3 = 0.0f;
3711#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3712
// Main loop: each iteration consumes 2 elements from each row of A and the matching 2 rows of B.
// It stops once fewer than 2 elements of the row remain.
Georgios Pinitas96880cf2017-10-20 18:52:20 +01003713 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003714 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003715#if defined(REINTERPRET_INPUT_AS_3D)
3716 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01003717 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3718#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003719 // Load values from matrix A
3720 VEC_DATA_TYPE(DATA_TYPE, 2)
3721 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3722#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3723 VEC_DATA_TYPE(DATA_TYPE, 2)
3724 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3725#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3726#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3727 VEC_DATA_TYPE(DATA_TYPE, 2)
3728 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3729#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3730#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3731 VEC_DATA_TYPE(DATA_TYPE, 2)
3732 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3733#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003734#endif // defined(REINTERPRET_INPUT_AS_3D)
3735
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003736 // Load values from matrix B
3737 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
3738 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003739
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003740 // Accumulate
3741 acc0 += b0 * (VECTOR_TYPE)a0.s0;
3742 acc0 += b1 * (VECTOR_TYPE)a0.s1;
3743#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3744 acc1 += b0 * (VECTOR_TYPE)a1.s0;
3745 acc1 += b1 * (VECTOR_TYPE)a1.s1;
3746#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3747#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3748 acc2 += b0 * (VECTOR_TYPE)a2.s0;
3749 acc2 += b1 * (VECTOR_TYPE)a2.s1;
3750#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3751#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3752 acc3 += b0 * (VECTOR_TYPE)a3.s0;
3753 acc3 += b1 * (VECTOR_TYPE)a3.s1;
3754#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003755 }
3756
// Leftover loop: consumes the at most one element left over by the 2-wide main loop
// (i.e. the last column when COLS_A is odd).
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003757 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003758 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003759#if defined(REINTERPRET_INPUT_AS_3D)
3760 // Load values from matrix A
3761 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3762#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3763 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3764#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3765#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3766 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3767#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3768#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3769 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3770#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3771#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003772 // Load values from matrix A
3773 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3774#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3775 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3776#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3777#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3778 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3779#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3780#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3781 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3782#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003783#endif // defined(REINTERPRET_INPUT_AS_3D)
3784
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003785 // Load values from matrix B
3786 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003787
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003788 // Accumulate
3789 acc0 += b0 * (VECTOR_TYPE)a0;
3790#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3791 acc1 += b0 * (VECTOR_TYPE)a1;
3792#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3793#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3794 acc2 += b0 * (VECTOR_TYPE)a2;
3795#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3796#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3797 acc3 += b0 * (VECTOR_TYPE)a3;
3798#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003799 }
3800
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003801 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003802 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3803
Gian Marcoae2af742018-02-15 12:35:44 +00003804 // Compute dst address
3805 __global uchar *dst_addr = offset(&dst, 0, 0);
3806
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003807 // Multiply by the weight of matrix-matrix product and store the result
// NOTE: ALPHA scaling happens before the optional ADD_VEC_C addition below, so src2 is not scaled by ALPHA
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003808#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003809 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003810#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003811#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3812 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
3813#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
3814#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3815 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
3816#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
3817#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3818 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
3819#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
3820
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003821#if defined(ADD_VEC_C)
3822 // *INDENT-OFF*
3823 // clang-format off
3824 __global DATA_TYPE *src2_addr = (__global DATA_TYPE *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
3825 VECTOR_TYPE c0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src2_addr);
3826 // clang-format on
3827 // *INDENT-ON*
3828
// src2 is indexed only by get_global_id(0), so the same vector c0 is added to every row
// of the output block (broadcast along Y, independent of get_global_id(1)/get_global_id(2)).
3829 acc0 += c0;
3830#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3831 acc1 += c0;
3832#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3833#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3834 acc2 += c0;
3835#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3836#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3837 acc3 += c0;
3838#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
3839#endif /* defined(ADD_VEC_C) */
3840
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003841 int z = get_global_id(2);
3842
3843#if defined(REINTERPRET_OUTPUT_AS_3D)
3844 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003845 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003846 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003847 // | |
3848 // | plane0 |
3849 // | |
3850 // |__________________|
3851 // |******************|
3852 // | cross_plane_pad |
3853 // |******************|
3854 // | |
3855 // | plane1 |
3856 // | |
3857 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003858
3859 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3860 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
// Clamp each lane's plane index to the last valid plane (DEPTH_GEMM3D - 1)
3861 zout = min(DEPTH_GEMM3D - 1, zout);
3862
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003863 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003864 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003865
3866 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3867 // multiply dst_stride_z by DEPTH_GEMM3D
3868 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3869
3870 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01003871 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003872#else // defined(REINTERPRET_OUTPUT_AS_3D)
3873 // Add offset for batched GEMM
3874 dst_addr += z * dst_stride_z;
3875
3876 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003877 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003878 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003879#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003880 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003881 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003882#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
3883#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003884 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003885 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003886#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
3887#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003888 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00003889 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003890#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003891#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003892}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003893#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003894
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01003895/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003896 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003897 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
3898 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003899 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
3900 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
3901 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
3902 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
3903 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003904 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
3905 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003906 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003907 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3908 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003909 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3910 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3911 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3912 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
3913 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003914 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
3915 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003916 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
3917 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3918 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3919 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3920 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3921 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3922 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3923 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3924 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3925 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3926 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3927 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003928 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
3929 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
3930 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
3931 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003932 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
3933 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3934 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
3935 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3936 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
3937 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003938 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3939 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
3940 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003941 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
3942 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003943 */
3944__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
3945 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003946#if defined(ADD_VEC_C)
3947 VECTOR_DECLARATION(src2),
3948#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00003949 IMAGE_DECLARATION(dst),
3950 uint src0_stride_z,
3951 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003952 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003953#if defined(REINTERPRET_INPUT_AS_3D)
3954 ,
3955 uint src_cross_plane_pad
3956#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003957#if defined(REINTERPRET_OUTPUT_AS_3D)
3958 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003959 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003960#endif // REINTERPRET_OUTPUT_AS_3D
3961 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003962{
3963 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
3964
3965 // Compute starting address for matrix A and matrix B
3966 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
3967
3968 // Update address for matrix A
3969 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
3970
3971 // Update address for matrix B
3972 src_addr.s1 += idx * sizeof(float);
3973
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003974#if defined(REINTERPRET_INPUT_AS_3D)
3975 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
3976 // in order to take into account the presence of possible cross plane paddings
3977 //
3978 // | |
3979 // | plane0 |
3980 // | |
3981 // |__________________|
3982 // |******************|
3983 // | cross_plane_pad |
3984 // |******************|
3985 // | |
3986 // | plane1 |
3987 // | |
3988 // |__________________|
3989
3990 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
3991 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
3992 zin = min(DEPTH_GEMM3D - 1, zin);
3993
3994 // Add offset due to the cross plane paddings
3995 zin *= (src_cross_plane_pad * src0_stride_y);
3996
3997 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3998 // multiply src0_stride_z by DEPTH_GEMM3D
3999 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4000
4001#else // defined(REINTERPRET_INPUT_AS_3D)
4002
Gian Marcoae2af742018-02-15 12:35:44 +00004003 // Add offset for batched GEMM
4004 src_addr.s0 += get_global_id(2) * src0_stride_z;
4005
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004006#endif // defined(REINTERPRET_INPUT_AS_3D)
4007
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004008#if defined(MATRIX_B_DEPTH)
4009 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4010 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4011#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004012 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004013#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004014
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004015 // Initialize accumulators
4016 float acc00 = 0.0f;
4017 float acc01 = 0.0f;
4018 float acc02 = 0.0f;
4019 float acc03 = 0.0f;
4020
4021#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4022 float acc10 = 0.0f;
4023 float acc11 = 0.0f;
4024 float acc12 = 0.0f;
4025 float acc13 = 0.0f;
4026#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4027
4028#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4029 float acc20 = 0.0f;
4030 float acc21 = 0.0f;
4031 float acc22 = 0.0f;
4032 float acc23 = 0.0f;
4033#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4034
4035#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4036 float acc30 = 0.0f;
4037 float acc31 = 0.0f;
4038 float acc32 = 0.0f;
4039 float acc33 = 0.0f;
4040#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4041
4042 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004043 int i = 0;
4044 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004045 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004046#if defined(REINTERPRET_INPUT_AS_3D)
4047 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01004048 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
4049#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004050 // Load values from matrix A and matrix B
4051 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004052#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004053 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004054#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4055#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004056 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004057#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4058#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004059 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004060#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004061#endif // defined(REINTERPRET_INPUT_AS_3D)
4062
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004063 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4064 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004065
4066 // Multiply and accumulate
4067 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004068 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004069 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004070 acc03 = fma(a0.s0, b0.s3, acc03);
4071
4072#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004073
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004074 acc10 = fma(a1.s0, b0.s0, acc10);
4075 acc11 = fma(a1.s0, b0.s1, acc11);
4076 acc12 = fma(a1.s0, b0.s2, acc12);
4077 acc13 = fma(a1.s0, b0.s3, acc13);
4078
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004079#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4080#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004081
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004082 acc20 = fma(a2.s0, b0.s0, acc20);
4083 acc21 = fma(a2.s0, b0.s1, acc21);
4084 acc22 = fma(a2.s0, b0.s2, acc22);
4085 acc23 = fma(a2.s0, b0.s3, acc23);
4086
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004087#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4088#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004089
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004090 acc30 = fma(a3.s0, b0.s0, acc30);
4091 acc31 = fma(a3.s0, b0.s1, acc31);
4092 acc32 = fma(a3.s0, b0.s2, acc32);
4093 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004094#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004095
4096 // Load values from matrix A and matrix B
4097 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4098 src_addr.s1 += src1_stride_y;
4099
4100 // Multiply and accumulate
4101 acc00 = fma(a0.s1, b0.s0, acc00);
4102 acc01 = fma(a0.s1, b0.s1, acc01);
4103 acc02 = fma(a0.s1, b0.s2, acc02);
4104 acc03 = fma(a0.s1, b0.s3, acc03);
4105
4106#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4107
4108 acc10 = fma(a1.s1, b0.s0, acc10);
4109 acc11 = fma(a1.s1, b0.s1, acc11);
4110 acc12 = fma(a1.s1, b0.s2, acc12);
4111 acc13 = fma(a1.s1, b0.s3, acc13);
4112
4113#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4114#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4115
4116 acc20 = fma(a2.s1, b0.s0, acc20);
4117 acc21 = fma(a2.s1, b0.s1, acc21);
4118 acc22 = fma(a2.s1, b0.s2, acc22);
4119 acc23 = fma(a2.s1, b0.s3, acc23);
4120
4121#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4122#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4123
4124 acc30 = fma(a3.s1, b0.s0, acc30);
4125 acc31 = fma(a3.s1, b0.s1, acc31);
4126 acc32 = fma(a3.s1, b0.s2, acc32);
4127 acc33 = fma(a3.s1, b0.s3, acc33);
4128#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4129
4130 // Load values from matrix A and matrix B
4131 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4132 src_addr.s1 += src1_stride_y;
4133
4134 // Multiply and accumulate
4135 acc00 = fma(a0.s2, b0.s0, acc00);
4136 acc01 = fma(a0.s2, b0.s1, acc01);
4137 acc02 = fma(a0.s2, b0.s2, acc02);
4138 acc03 = fma(a0.s2, b0.s3, acc03);
4139
4140#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4141
4142 acc10 = fma(a1.s2, b0.s0, acc10);
4143 acc11 = fma(a1.s2, b0.s1, acc11);
4144 acc12 = fma(a1.s2, b0.s2, acc12);
4145 acc13 = fma(a1.s2, b0.s3, acc13);
4146
4147#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4148#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4149
4150 acc20 = fma(a2.s2, b0.s0, acc20);
4151 acc21 = fma(a2.s2, b0.s1, acc21);
4152 acc22 = fma(a2.s2, b0.s2, acc22);
4153 acc23 = fma(a2.s2, b0.s3, acc23);
4154
4155#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4156#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4157
4158 acc30 = fma(a3.s2, b0.s0, acc30);
4159 acc31 = fma(a3.s2, b0.s1, acc31);
4160 acc32 = fma(a3.s2, b0.s2, acc32);
4161 acc33 = fma(a3.s2, b0.s3, acc33);
4162#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4163
4164 // Load values from matrix A and matrix B
4165 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
4166 src_addr.s1 += src1_stride_y;
4167
4168 // Multiply and accumulate
4169 acc00 = fma(a0.s3, b0.s0, acc00);
4170 acc01 = fma(a0.s3, b0.s1, acc01);
4171 acc02 = fma(a0.s3, b0.s2, acc02);
4172 acc03 = fma(a0.s3, b0.s3, acc03);
4173
4174#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4175
4176 acc10 = fma(a1.s3, b0.s0, acc10);
4177 acc11 = fma(a1.s3, b0.s1, acc11);
4178 acc12 = fma(a1.s3, b0.s2, acc12);
4179 acc13 = fma(a1.s3, b0.s3, acc13);
4180
4181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4182#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4183
4184 acc20 = fma(a2.s3, b0.s0, acc20);
4185 acc21 = fma(a2.s3, b0.s1, acc21);
4186 acc22 = fma(a2.s3, b0.s2, acc22);
4187 acc23 = fma(a2.s3, b0.s3, acc23);
4188
4189#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4190#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4191
4192 acc30 = fma(a3.s3, b0.s0, acc30);
4193 acc31 = fma(a3.s3, b0.s1, acc31);
4194 acc32 = fma(a3.s3, b0.s2, acc32);
4195 acc33 = fma(a3.s3, b0.s3, acc33);
4196#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4197
4198 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004199 }
4200
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004201 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004202 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004203#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004204 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004205 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4206#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4207 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4208#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4209#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4210 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4211#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4212#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4213 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4214#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4215#else // defined(REINTERPRET_INPUT_AS_3D)
4216 // Load values from matrix A
4217 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004218#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4219 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4220#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4221#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4222 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4223#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4224#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4225 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4226#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004227#endif // defined(REINTERPRET_INPUT_AS_3D)
4228
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004229 // Load values from matrix B
4230 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004231 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004232
4233 // Multiply and accumulate
4234 acc00 = fma(a0, b0.s0, acc00);
4235 acc01 = fma(a0, b0.s1, acc01);
4236 acc02 = fma(a0, b0.s2, acc02);
4237 acc03 = fma(a0, b0.s3, acc03);
4238#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4239 acc10 = fma(a1, b0.s0, acc10);
4240 acc11 = fma(a1, b0.s1, acc11);
4241 acc12 = fma(a1, b0.s2, acc12);
4242 acc13 = fma(a1, b0.s3, acc13);
4243#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4244#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4245 acc20 = fma(a2, b0.s0, acc20);
4246 acc21 = fma(a2, b0.s1, acc21);
4247 acc22 = fma(a2, b0.s2, acc22);
4248 acc23 = fma(a2, b0.s3, acc23);
4249#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4250#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4251 acc30 = fma(a3, b0.s0, acc30);
4252 acc31 = fma(a3, b0.s1, acc31);
4253 acc32 = fma(a3, b0.s2, acc32);
4254 acc33 = fma(a3, b0.s3, acc33);
4255#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004256
4257 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004258 }
4259
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004260 int z = get_global_id(2);
4261
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004262 // Compute destination address
4263 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4264
4265 // Multiply by the weight of matrix-matrix product and store the result
4266#if defined(ALPHA)
4267 acc00 = acc00 * ALPHA;
4268 acc01 = acc01 * ALPHA;
4269 acc02 = acc02 * ALPHA;
4270 acc03 = acc03 * ALPHA;
4271#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004272#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004273 acc10 = acc10 * ALPHA;
4274 acc11 = acc11 * ALPHA;
4275 acc12 = acc12 * ALPHA;
4276 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004279 acc20 = acc20 * ALPHA;
4280 acc21 = acc21 * ALPHA;
4281 acc22 = acc22 * ALPHA;
4282 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004283#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4284#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004285 acc30 = acc30 * ALPHA;
4286 acc31 = acc31 * ALPHA;
4287 acc32 = acc32 * ALPHA;
4288 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004289#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4290
4291 // Compute dst address
4292 __global uchar *dst_addr = offset(&dst, 0, 0);
4293
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004294#if defined(ADD_VEC_C)
4295 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4296 float4 c0 = vload4(0, src2_addr);
4297
4298 acc00 += c0.s0;
4299 acc01 += c0.s1;
4300 acc02 += c0.s2;
4301 acc03 += c0.s3;
4302#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4303 acc10 += c0.s0;
4304 acc11 += c0.s1;
4305 acc12 += c0.s2;
4306 acc13 += c0.s3;
4307#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4308#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4309 acc20 += c0.s0;
4310 acc21 += c0.s1;
4311 acc22 += c0.s2;
4312 acc23 += c0.s3;
4313#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4315 acc30 += c0.s0;
4316 acc31 += c0.s1;
4317 acc32 += c0.s2;
4318 acc33 += c0.s3;
4319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4320#endif /* defined(ADD_VEC_C) */
4321
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004322#if defined(REINTERPRET_OUTPUT_AS_3D)
4323 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004324 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004325 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004326 // | |
4327 // | plane0 |
4328 // | |
4329 // |__________________|
4330 // |******************|
4331 // | cross_plane_pad |
4332 // |******************|
4333 // | |
4334 // | plane1 |
4335 // | |
4336 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004337
4338 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4339 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4340 zout = min(DEPTH_GEMM3D - 1, zout);
4341
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004342 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004343 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004344
4345 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4346 // multiply dst_stride_z by DEPTH_GEMM3D
4347 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4348
4349 // Store the output block
4350 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4351#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4352 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4353#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4354#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4355 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4356#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4357#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4358 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004360
4361#else // defined(REINTERPRET_OUTPUT_AS_3D)
4362 // Add offset for batched GEMM
4363 dst_addr += z * dst_stride_z;
4364
4365 // Store the output block
4366 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4367#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4368 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4369#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4370#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4371 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4372#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4373#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4374 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4375#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4376#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004377}
4378
4379/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
4380 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004381 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
4382 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004383 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4384 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
4385 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4386 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
4387 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4388 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004389 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4390 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004391 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004392 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4393 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004394 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4395 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4396 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4397 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
4398 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004399 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
4400 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004401 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
4402 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4403 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
4404 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4405 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
4406 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4407 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4408 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4409 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
4410 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4411 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
4412 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004413 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
4414 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
4415 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
4416 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004417 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
4418 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4419 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
4420 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4421 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
4422 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004423 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4424 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4425 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004426 * @param[in] src_cross_plane_pad (Optional) Bottom paddings, in number of rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4427 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings, in number of rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004428 */
4429__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
4430 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004431#if defined(ADD_VEC_C)
4432 VECTOR_DECLARATION(src2),
4433#endif /* defined(ADD_VEC_C) */
Gian Marcoae2af742018-02-15 12:35:44 +00004434 IMAGE_DECLARATION(dst),
4435 uint src0_stride_z,
4436 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004437 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004438#if defined(REINTERPRET_INPUT_AS_3D)
4439 ,
4440 uint src_cross_plane_pad
4441#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004442#if defined(REINTERPRET_OUTPUT_AS_3D)
4443 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004444 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004445#endif // REINTERPRET_OUTPUT_AS_3D
4446 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004447{
4448 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4449 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
4450
4451 // Compute starting address for matrix A and Matrix B
4452 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
4453
4454 // Update address for the matrix A
4455 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
4456
4457 // Update address for the matrix B
4458 src_addr.s1 += idx * sizeof(float);
4459
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004460#if defined(REINTERPRET_INPUT_AS_3D)
4461 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
4462 // in order to take into account the presence of possible cross plane paddings
4463 //
4464 // | |
4465 // | plane0 |
4466 // | |
4467 // |__________________|
4468 // |******************|
4469 // | cross_plane_pad |
4470 // |******************|
4471 // | |
4472 // | plane1 |
4473 // | |
4474 // |__________________|
4475
4476 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4477 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4478 zin = min(DEPTH_GEMM3D - 1, zin);
4479
4480 // Add offset due to the cross plane paddings
4481 zin *= (src_cross_plane_pad * src0_stride_y);
4482
4483 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4484 // multiply src0_stride_z by DEPTH_GEMM3D
4485 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
4486
4487#else // defined(REINTERPRET_INPUT_AS_3D)
4488
Gian Marcoae2af742018-02-15 12:35:44 +00004489 // Add offset for batched GEMM
4490 src_addr.s0 += get_global_id(2) * src0_stride_z;
4491
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004492#endif // defined(REINTERPRET_INPUT_AS_3D)
4493
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004494#if defined(MATRIX_B_DEPTH)
4495 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4496 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
4497#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004498 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004499#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00004500
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004501 // Initialize accumulators
4502 float acc00 = 0.0f;
4503 float acc01 = 0.0f;
4504
4505#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4506 float acc10 = 0.0f;
4507 float acc11 = 0.0f;
4508#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4509#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4510 float acc20 = 0.0f;
4511 float acc21 = 0.0f;
4512#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4513#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4514 float acc30 = 0.0f;
4515 float acc31 = 0.0f;
4516#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4517
4518 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004519 int i = 0;
4520 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004521 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004522#if defined(REINTERPRET_INPUT_AS_3D)
4523 // Load values from matrix A
4524 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
4525#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004526 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004527 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004528#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004529
4530 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004531 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4532 src_addr.s1 += src1_stride_y;
4533 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4534 src_addr.s1 += src1_stride_y;
4535 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4536 src_addr.s1 += src1_stride_y;
4537 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4538 src_addr.s1 += src1_stride_y;
4539 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4540 src_addr.s1 += src1_stride_y;
4541 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4542 src_addr.s1 += src1_stride_y;
4543 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4544 src_addr.s1 += src1_stride_y;
4545 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
4546 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004547
4548 // Multiply and accumulate
4549 acc00 = fma(a0.s0, b0.s0, acc00);
4550 acc00 = fma(a0.s1, b1.s0, acc00);
4551 acc00 = fma(a0.s2, b2.s0, acc00);
4552 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004553 acc00 = fma(a0.s4, b4.s0, acc00);
4554 acc00 = fma(a0.s5, b5.s0, acc00);
4555 acc00 = fma(a0.s6, b6.s0, acc00);
4556 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004557
4558 acc01 = fma(a0.s0, b0.s1, acc01);
4559 acc01 = fma(a0.s1, b1.s1, acc01);
4560 acc01 = fma(a0.s2, b2.s1, acc01);
4561 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004562 acc01 = fma(a0.s4, b4.s1, acc01);
4563 acc01 = fma(a0.s5, b5.s1, acc01);
4564 acc01 = fma(a0.s6, b6.s1, acc01);
4565 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004566
4567#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004568#if defined(REINTERPRET_INPUT_AS_3D)
4569 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4570#else // defined(REINTERPRET_INPUT_AS_3D)
4571 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4572#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004573 acc10 = fma(a0.s0, b0.s0, acc10);
4574 acc10 = fma(a0.s1, b1.s0, acc10);
4575 acc10 = fma(a0.s2, b2.s0, acc10);
4576 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004577 acc10 = fma(a0.s4, b4.s0, acc10);
4578 acc10 = fma(a0.s5, b5.s0, acc10);
4579 acc10 = fma(a0.s6, b6.s0, acc10);
4580 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004581
4582 acc11 = fma(a0.s0, b0.s1, acc11);
4583 acc11 = fma(a0.s1, b1.s1, acc11);
4584 acc11 = fma(a0.s2, b2.s1, acc11);
4585 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004586 acc11 = fma(a0.s4, b4.s1, acc11);
4587 acc11 = fma(a0.s5, b5.s1, acc11);
4588 acc11 = fma(a0.s6, b6.s1, acc11);
4589 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004590#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4591#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004592#if defined(REINTERPRET_INPUT_AS_3D)
4593 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4594#else // defined(REINTERPRET_INPUT_AS_3D)
4595 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4596#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004597 acc20 = fma(a0.s0, b0.s0, acc20);
4598 acc20 = fma(a0.s1, b1.s0, acc20);
4599 acc20 = fma(a0.s2, b2.s0, acc20);
4600 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004601 acc20 = fma(a0.s4, b4.s0, acc20);
4602 acc20 = fma(a0.s5, b5.s0, acc20);
4603 acc20 = fma(a0.s6, b6.s0, acc20);
4604 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004605
4606 acc21 = fma(a0.s0, b0.s1, acc21);
4607 acc21 = fma(a0.s1, b1.s1, acc21);
4608 acc21 = fma(a0.s2, b2.s1, acc21);
4609 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004610 acc21 = fma(a0.s4, b4.s1, acc21);
4611 acc21 = fma(a0.s5, b5.s1, acc21);
4612 acc21 = fma(a0.s6, b6.s1, acc21);
4613 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004614#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4615#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004616#if defined(REINTERPRET_INPUT_AS_3D)
4617 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4618#else // defined(REINTERPRET_INPUT_AS_3D)
4619 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4620#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004621 acc30 = fma(a0.s0, b0.s0, acc30);
4622 acc30 = fma(a0.s1, b1.s0, acc30);
4623 acc30 = fma(a0.s2, b2.s0, acc30);
4624 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004625 acc30 = fma(a0.s4, b4.s0, acc30);
4626 acc30 = fma(a0.s5, b5.s0, acc30);
4627 acc30 = fma(a0.s6, b6.s0, acc30);
4628 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004629
4630 acc31 = fma(a0.s0, b0.s1, acc31);
4631 acc31 = fma(a0.s1, b1.s1, acc31);
4632 acc31 = fma(a0.s2, b2.s1, acc31);
4633 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004634 acc31 = fma(a0.s4, b4.s1, acc31);
4635 acc31 = fma(a0.s5, b5.s1, acc31);
4636 acc31 = fma(a0.s6, b6.s1, acc31);
4637 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004639
4640 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004641 }
4642 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004643 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004644 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004645#if defined(REINTERPRET_INPUT_AS_3D)
4646 // Load values from matrix A
4647 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
4648#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4649 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
4650#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4651#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4652 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
4653#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4654#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4655 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
4656#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4657#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004658 // Load values from matrix A
4659 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
4660#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4661 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
4662#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4663#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4664 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
4665#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4666#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4667 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
4668#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004669#endif // defined(REINTERPRET_INPUT_AS_3D)
4670
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004671 // Load values from matrix B
4672 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004673 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004674
4675 // Multiply and accumulate
4676 acc00 = fma(a0, b0.s0, acc00);
4677 acc01 = fma(a0, b0.s1, acc01);
4678#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4679 acc10 = fma(a1, b0.s0, acc10);
4680 acc11 = fma(a1, b0.s1, acc11);
4681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4683 acc20 = fma(a2, b0.s0, acc20);
4684 acc21 = fma(a2, b0.s1, acc21);
4685#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4686#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4687 acc30 = fma(a3, b0.s0, acc30);
4688 acc31 = fma(a3, b0.s1, acc31);
4689#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004690
4691 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004692 }
4693
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004694 // Multiply by the weight of matrix-matrix product and store the result
4695#if defined(ALPHA)
4696 acc00 = acc00 * ALPHA;
4697 acc01 = acc01 * ALPHA;
4698#endif // defined(ALPHA)
4699#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4700 acc10 = acc10 * ALPHA;
4701 acc11 = acc11 * ALPHA;
4702#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
4703#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4704 acc20 = acc20 * ALPHA;
4705 acc21 = acc21 * ALPHA;
4706#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
4707#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4708 acc30 = acc30 * ALPHA;
4709 acc31 = acc31 * ALPHA;
4710#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
4711
4712 int z = get_global_id(2);
4713
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004714 // Compute destination address
4715 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4716
Gian Marcoae2af742018-02-15 12:35:44 +00004717 // Compute dst address
4718 __global uchar *dst_addr = offset(&dst, 0, 0);
4719
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004720#if defined(ADD_VEC_C)
4721 __global float *src2_addr = (__global float *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
4722 float2 c0 = vload2(0, src2_addr);
4723
4724 acc00 += c0.s0;
4725 acc01 += c0.s1;
4726#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4727 acc10 += c0.s0;
4728 acc11 += c0.s1;
4729#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4730#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4731 acc20 += c0.s0;
4732 acc21 += c0.s1;
4733#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4734#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4735 acc30 += c0.s0;
4736 acc31 += c0.s1;
4737#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4738#endif /* defined(ADD_VEC_C) */
4739
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004740#if defined(REINTERPRET_OUTPUT_AS_3D)
4741 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004742 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004743 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004744 // | |
4745 // | plane0 |
4746 // | |
4747 // |__________________|
4748 // |******************|
4749 // | cross_plane_pad |
4750 // |******************|
4751 // | |
4752 // | plane1 |
4753 // | |
4754 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00004755
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004756 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
4757 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
4758 zout = min(DEPTH_GEMM3D - 1, zout);
4759
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004760 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004761 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004762
4763 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4764 // multiply dst_stride_z by DEPTH_GEMM3D
4765 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
4766
4767 // Store the output block
4768 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004769#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004770 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004771#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4772#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004773 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004774#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4775#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004776 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004777#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004778
4779#else // defined(REINTERPRET_OUTPUT_AS_3D)
4780 // Add offset for batched GEMM
4781 dst_addr += z * dst_stride_z;
4782
4783 // Store the output block
4784 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
4785#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4786 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
4787#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
4788#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4789 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
4790#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
4791#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4792 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
4793#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
4794#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004795}
4796
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01004797#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point (float) variables.
 *       The optional scaling by ALPHA and the optional addition of src2 are performed in 32-bit floating point as well; the accumulators are converted
 *       to half only right before the store, so a partial result that exceeds the half range is not turned into infinity before ALPHA scales it down.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads and stores 8 output columns (vload8/vstore8), so NUM_ELEMS_PROCESSED_PER_THREAD_X is expected to be 8.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y can range from 1 to 4.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the source vector in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source vector
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(ADD_VEC_C)
                                                       VECTOR_DECLARATION(src2),
#endif /* defined(ADD_VEC_C) */
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // FP32 accumulators: one float8 per output row computed by this work-item
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: remaining COLS_A % 4 columns, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate: acc += b0 * broadcast(a) in FP32
        acc0 = fma(b0, (float8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product while still in FP32, so that a partial
    // result outside the half range is scaled before (not after) it is narrowed to half
#if defined(ALPHA)
    acc0 *= (float)(ALPHA);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 *= (float)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 *= (float)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 *= (float)(ALPHA);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

#if defined(ADD_VEC_C)
    // Add the src2 vector (broadcast across the rows of the tile) in FP32 as well
    // *INDENT-OFF*
    // clang-format off
    __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
    float8 c0 = convert_float8(vload8(0, src2_addr));
    // clang-format on
    // *INDENT-ON*

    acc0 += c0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 += c0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif /* defined(ADD_VEC_C) */

    // Narrow the final FP32 results to half, once, right before the store
    half8 hacc0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 hacc1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 hacc2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 hacc3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, hacc, dst_addr, dst_stride_y, zout.s);
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
5156
5157/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5158 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005159 * Moreover, it can add a vector (src2) if the ADD_VEC_C parameter is passed at compile time.
5160 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005161 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
5162 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5163 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5164 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5165 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
5166 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
5167 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
5168 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005169 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5170 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005171 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5172 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5173 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5174 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5175 *
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005176 * @note In case a 3rd input (src2) needs to be added, the ADD_VEC_C parameter has to be passed at compile time as -DADD_VEC_C
5177 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005178 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5179 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5180 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5181 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5182 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5183 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5184 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5185 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5186 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5187 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5188 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5189 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005190 * @param[in] src2_ptr (Optional) Pointer to the source vector. Supported data types: same as @p src0_ptr
5191 * @param[in] src2_stride_x (Optional) Stride of the source vector in X dimension (in bytes)
5192 * @param[in] src2_step_x (Optional) src_stride_x * number of elements along X processed per workitem(in bytes)
5193 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the source matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005194 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5195 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5196 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5197 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5198 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5199 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005200 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5201 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
5202 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005203 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5204 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005205 */
5206__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
5207 IMAGE_DECLARATION(src1),
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005208#if defined(ADD_VEC_C)
5209 VECTOR_DECLARATION(src2),
5210#endif /* defined(ADD_VEC_C) */
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005211 IMAGE_DECLARATION(dst),
5212 uint src0_stride_z,
5213 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005214 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005215#if defined(REINTERPRET_INPUT_AS_3D)
5216 ,
5217 uint src_cross_plane_pad
5218#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005219#if defined(REINTERPRET_OUTPUT_AS_3D)
5220 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005221 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005222#endif // REINTERPRET_OUTPUT_AS_3D
5223 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005224{
5225 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5226
5227 // Compute starting address for matrix A and Matrix B
5228 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5229
5230 // Update address for the matrix A
5231 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5232
5233 // Update address for the matrix B
5234 src_addr.s1 += idx * sizeof(half);
5235
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005236#if defined(REINTERPRET_INPUT_AS_3D)
5237 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5238 // in order to take into account the presence of possible cross plane paddings
5239 //
5240 // | |
5241 // | plane0 |
5242 // | |
5243 // |__________________|
5244 // |******************|
5245 // | cross_plane_pad |
5246 // |******************|
5247 // | |
5248 // | plane1 |
5249 // | |
5250 // |__________________|
5251
5252 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5253 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5254 zin = min(DEPTH_GEMM3D - 1, zin);
5255
5256 // Add offset due to the cross plane paddings
5257 zin *= (src_cross_plane_pad * src0_stride_y);
5258
5259 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5260 // multiply src0_stride_z by DEPTH_GEMM3D
5261 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5262
5263#else // defined(REINTERPRET_INPUT_AS_3D)
5264
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005265 // Add offset for batched GEMM
5266 src_addr.s0 += get_global_id(2) * src0_stride_z;
5267
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005268#endif // defined(REINTERPRET_INPUT_AS_3D)
5269
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005270#if defined(MATRIX_B_DEPTH)
5271 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5272 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5273#else // defined(MATRIX_B_DEPTH)
5274 src_addr.s1 += get_global_id(2) * src1_stride_z;
5275#endif // defined(MATRIX_B_DEPTH)
5276
5277 half8 acc0 = 0.0h;
5278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5279 half8 acc1 = 0.0h;
5280#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5281#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5282 half8 acc2 = 0.0h;
5283#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5284#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5285 half8 acc3 = 0.0h;
5286#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5287
5288 int i = 0;
5289 for(; i <= ((int)COLS_A - 4); i += 4)
5290 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005291#if defined(REINTERPRET_INPUT_AS_3D)
5292 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005293 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5294#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005295 // Load values from matrix A
5296 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5297#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5298 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5299#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5300#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5301 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5302#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5303#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5304 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5305#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005306#endif // defined(REINTERPRET_INPUT_AS_3D)
5307
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005308 // Load values from matrix B
5309 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5310 src_addr.s1 += src1_stride_y;
5311
5312 // Accumulate
5313 acc0 = fma(b0, (half8)a0.s0, acc0);
5314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5315 acc1 = fma(b0, (half8)a1.s0, acc1);
5316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5318 acc2 = fma(b0, (half8)a2.s0, acc2);
5319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5320#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5321 acc3 = fma(b0, (half8)a3.s0, acc3);
5322#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5323
5324 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5325 src_addr.s1 += src1_stride_y;
5326 acc0 = fma(b0, (half8)a0.s1, acc0);
5327#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5328 acc1 = fma(b0, (half8)a1.s1, acc1);
5329#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5330#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5331 acc2 = fma(b0, (half8)a2.s1, acc2);
5332#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5333#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5334 acc3 = fma(b0, (half8)a3.s1, acc3);
5335#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5336
5337 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5338 src_addr.s1 += src1_stride_y;
5339 acc0 = fma(b0, (half8)a0.s2, acc0);
5340#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5341 acc1 = fma(b0, (half8)a1.s2, acc1);
5342#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5343#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5344 acc2 = fma(b0, (half8)a2.s2, acc2);
5345#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5346#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5347 acc3 = fma(b0, (half8)a3.s2, acc3);
5348#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5349
5350 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5351 src_addr.s1 += src1_stride_y;
5352 acc0 = fma(b0, (half8)a0.s3, acc0);
5353#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5354 acc1 = fma(b0, (half8)a1.s3, acc1);
5355#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5356#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5357 acc2 = fma(b0, (half8)a2.s3, acc2);
5358#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5359#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5360 acc3 = fma(b0, (half8)a3.s3, acc3);
5361#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5362
5363 src_addr.s0 += 4 * sizeof(half);
5364 }
5365
5366 for(; i < (int)COLS_A; ++i)
5367 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005368#if defined(REINTERPRET_INPUT_AS_3D)
5369 // Load values from matrix A
5370 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5371#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5372 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5373#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5374#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5375 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5376#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5378 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5379#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5380#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005381 // Load values from matrix A
5382 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5383#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5384 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5385#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5386#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5387 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5388#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5389#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5390 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5391#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005392#endif // defined(REINTERPRET_INPUT_AS_3D)
5393
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005394 // Load values from matrix B
5395 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
5396
5397 src_addr += (int2)(sizeof(half), src1_stride_y);
5398
5399 // Accumulate
5400 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
5401#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5402 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
5403#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5404#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5405 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
5406#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5407#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5408 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
5409#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5410 }
5411
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005412 // Multiply by the weight of matrix-matrix product and store the result
5413#if defined(ALPHA)
5414 acc0 = acc0 * (half8)ALPHA;
5415#endif // defined(ALPHA)
5416#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5417 acc1 = acc1 * (half8)ALPHA;
5418#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
5419#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5420 acc2 = acc2 * (half8)ALPHA;
5421#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
5422#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5423 acc3 = acc3 * (half8)ALPHA;
5424#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
5425
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005426#if defined(ADD_VEC_C)
5427 // *INDENT-OFF*
5428 // clang-format off
5429 __global half *src2_addr = (__global half *)(src2_ptr + src2_offset_first_element_in_bytes + get_global_id(0) * src2_step_x);
5430 half8 c0 = vload8(0, src2_addr);
5431 // clang-format on
5432 // *INDENT-ON*
5433
5434 acc0 += c0;
5435#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5436 acc1 += c0;
5437#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5438#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5439 acc2 += c0;
5440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5442 acc3 += c0;
5443#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5444#endif /* defined(ADD_VEC_C) */
5445
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005446 int z = get_global_id(2);
5447
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005448 // Compute destination address
5449 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5450
5451 // Compute dst address
5452 __global uchar *dst_addr = offset(&dst, 0, 0);
5453
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005454#if defined(REINTERPRET_OUTPUT_AS_3D)
5455 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005456 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005457 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005458 // | |
5459 // | plane0 |
5460 // | |
5461 // |__________________|
5462 // |******************|
5463 // | cross_plane_pad |
5464 // |******************|
5465 // | |
5466 // | plane1 |
5467 // | |
5468 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005469
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005470 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5471 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5472 zout = min(DEPTH_GEMM3D - 1, zout);
5473
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005474 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005475 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005476
5477 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5478 // multiply dst_stride_z by DEPTH_GEMM3D
5479 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
5480
5481 // Store the output block
Usama Arif0681e3b2019-04-25 14:28:07 +01005482 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005483#else // defined(REINTERPRET_OUTPUT_AS_3D)
5484 // Add offset for batched GEMM
5485 dst_addr += z * dst_stride_z;
5486
5487 // Store the output block
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005488 vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
5489#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005490 vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
5491#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5492#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005493 vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
5494#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5495#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005496 vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
5497#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005498#endif // REINTERPRET_OUTPUT_AS_3D
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005499}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005500#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005501
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01005502#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005503
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005504#if defined(BETA)
/** In-place F32 matrix addition: dst = dst + BETA * src.
 *
 * Used to finish a GEMM: @p dst already holds alpha * (A x B) and @p src holds matrix C,
 * which is weighted by the scalar BETA and added into @p dst. Each work-item reads and
 * writes 4 consecutive float values.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix (C). Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix (read as alpha * A x B, overwritten with the result). Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination matrix in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per work-item views of matrix C (src) and of the alpha * A x B result (dst)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *out_ptr = (__global float *)dst.ptr;

    const float4 ab    = vload4(0, out_ptr);                     // alpha * A x B
    const float4 mat_c = vload4(0, (__global float *)src.ptr);   // C

    // alpha * A x B + beta * C, written back over dst
    vstore4(ab + (float4)BETA * mat_c, 0, out_ptr);
}
5545
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005546#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** In-place F16 matrix addition: dst = dst + BETA * src.
 *
 * Used to finish a GEMM: @p dst already holds alpha * (A x B) and @p src holds matrix C,
 * which is weighted by the scalar BETA and added into @p dst. Each work-item reads and
 * writes 8 consecutive half values.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix (C). Supported data types: F16
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix (read as alpha * A x B, overwritten with the result). Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination matrix in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Per work-item views of matrix C (src) and of the alpha * A x B result (dst)
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *out_ptr = (__global half *)dst.ptr;

    const half8 ab    = vload8(0, out_ptr);                    // alpha * A x B
    const half8 mat_c = vload8(0, (__global half *)src.ptr);   // C

    // alpha * A x B + beta * C, written back over dst
    vstore8(ab + (half8)BETA * mat_c, 0, out_ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01005587#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005588#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01005589
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005590#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix product between each row of A (src0) and its own Z-slice of B (src1), used for the locally connected layer.
 *
 * Work-item (x, y) multiplies row y of A by columns [4x, 4x + 3] of the y-th Z-slice of B
 * and stores the resulting 4 floats through the dst image view.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int col = get_global_id(0) * 4; // first of the 4 output columns handled here
    const int row = get_global_id(1);     // row of A, and Z-slice of B

    // Byte offsets: a_off walks along row "row" of A, b_off walks down the "row"-th Z-slice of B
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += col * sizeof(float);

    const int a_end = a_off + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: two elements of A (and two rows of B) per iteration
    while(a_off <= (a_end - 2 * (int)sizeof(float)))
    {
        const float2 a      = vload2(0, (__global float *)(src0_ptr + a_off));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_off));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y));

        acc += b_row0 * (float4)a.s0;
        acc += b_row1 * (float4)a.s1;

        a_off += 2 * (int)sizeof(float);
        b_off += (int)(2 * src1_stride_y);
    }

    // Leftover element of A (at most one, when WIDTH_VECTOR_A is odd)
    while(a_off < a_end)
    {
        const float  a     = *((__global float *)(src0_ptr + a_off));
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + b_off));

        acc += b_row * (float4)a;

        a_off += (int)sizeof(float);
        b_off += (int)src1_stride_y;
    }

    // Store the 4 results through this work-item's destination view
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005656#endif // defined(WIDTH_VECTOR_A)
5657
/** This kernel accumulates each row with the biases vector (in place: accum = biases + accum).
 *
 * Each work-item reads VECTOR_SIZE elements from @p accum and VECTOR_SIZE elements from
 * @p biases, adds them and writes the sum back to the same location in @p accum.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_data = (__global DATA_TYPE *)accum.ptr;

    // Current values of the accumulate buffer and the matching bias elements
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) acc  = VLOAD(VECTOR_SIZE)(0, accum_data);
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) bias = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    // Write the sum back in place
    VSTORE(VECTOR_SIZE)
    (bias + acc, 0, accum_data);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)