blob: 8a956010e74f87dd69d8571f6fda650929827c88 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Sheri Zhang1a378102020-04-30 12:59:39 +01002 * Copyright (c) 2017-2020 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +000027#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Per-lane column offsets 0..(N-1), one constant per supported K0 value.
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
// Two-level expansion so INC(K0) pastes the numeric value of K0 (e.g. INC4) rather than the token "K0"
#define CONCAT_INC(K0) INC##K0
#define INC(K0) CONCAT_INC(K0)

/** Zero the lanes of the K0-wide row vector @p a whose column index is not below SRC_WIDTH.
 *
 * Lane j of @p a is treated as column (x * K0 + j). When SRC_WIDTH is a multiple of K0 every
 * block is complete, so the macro expands to an empty statement expression.
 * Both arguments are parenthesized in the expansion so callers may pass arbitrary expressions.
 *
 * @param[in]     x Block index along X (the kernels pass get_global_id(0))
 * @param[in,out] a K0-element vector of DATA_TYPE; lanes at or beyond SRC_WIDTH are set to 0
 */
#if(SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({ \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: element offset of block (y % V0) within its output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: element distance between consecutive block rows, passed to STORE_BLOCK as the row stride.
    // Fully parenthesized so the expansion is safe inside any arithmetic expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    // Load values from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Zero the columns at or beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------
    // No per-row Z offset is applied to the output: zout0..zout(M0-1) are all 0.
    // Only M0 of them are needed because STORE_BLOCK stores M0 rows.
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}

/** Transpose column @p i of the M0xK0 block held in a0..a(M0-1) and store it as one M0-element output row.
 *
 * The macro reads the row vectors a0, a1, ..., a(M0-1) from the enclosing scope. It gathers lane @p i
 * of each row and writes the result at output_ptr + i * output_step_x elements of DATA_TYPE.
 * Arguments are parenthesized in the expansion so callers may pass arbitrary expressions.
 *
 * @param[in] output_ptr    Base address (__global uchar *) of the destination block
 * @param[in] output_step_x Distance, in elements of DATA_TYPE, between two consecutive stored rows
 * @param[in] i             Column index as a single hex-digit token (0-9, A-F). It is used both as the
 *                          vector component suffix (.s##i) and, via 0x##i, as the numeric row offset
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        DATA_TYPE res1 = a4.s##i; \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 2) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(2) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 3) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(3) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 * the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Block size
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: element offset of block (y % V0) within its output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: element distance between consecutive transposed rows, passed to TRANSPOSE_COLUMN_AND_STORE.
    // Fully parenthesized so the expansion is safe inside any arithmetic expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive blocks along Y share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE))
                                 + ((y / (uint)V0) * (uint)dst_stride_y)
                                 + ((y % V0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Zero the columns at or beyond SRC_WIDTH (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------
    // Each call gathers column i of a0..a(M0-1) and writes it as output row i (hex-digit index)

    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
Gian Marco Iodiceb87b95e2019-01-21 17:14:31 +0000399#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
Gian Marco Iodice5ba5e092018-12-06 17:13:09 +0000400
Gian Marco Iodice3b0a2652018-12-07 11:18:09 +0000401#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 * the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 horizontal blocks (consecutive along X) to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: element offset of block (x % H0) within its output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: element distance between consecutive block rows, passed to STORE_BLOCK as the row stride.
    // Fully parenthesized so the expansion is safe inside any arithmetic expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive blocks along X share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes
                                 + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE))
                                 + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE))
                                 + ((x / (uint)H0) * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, a2=0...a(K0-1)=0;
    // Rows at or beyond SRC_HEIGHT are never loaded below, so they are stored as zeros.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix.
    // Row 0 is loaded without a height check (presumably the Y global size is ceil(SRC_HEIGHT / K0) - confirm on host side)
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // No per-row Z offset is applied to the output: zout0..zout(K0-1) are all 0.
    // Only K0 of them are needed because STORE_BLOCK stores K0 rows.
    REPEAT_VAR_INIT_TO_CONST(K0, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}

#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * Each work-item loads one K0xN0 block (rows at or past SRC_HEIGHT, except the always-loaded first row, keep their zero initialisation),
 * transposes it into N0 vectors of K0 elements and stores them as an N0xK0 block in the output matrix.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 horizontal blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 2,3,4,8,16
 *                                      H0: greater than 0
 *       Any other K0 value is rejected with a compile-time error.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Block size
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between the H0 blocks sharing an output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between consecutive K0-wide rows of one stored block.
    // Fully parenthesised so the macro stays a single operand wherever it is expanded.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Compute source and destination addresses
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: block column x (N0 elements wide), block row y (K0 rows high)
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 horizontally adjacent blocks (same x / H0) share one output row,
    // each block row y occupies BLOCK_SIZE * H0 elements of that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix. Row 0 is always read; the remaining rows are read only
    // when they lie inside the source height, otherwise they keep their zero initialisation.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

    // res<n> gathers column n of the loaded block: (a0.s<n>, a1.s<n>, ..., a(K0-1).s<n>)
#if K0 == 2
    // This part computes the following transpositions:
    // 2x2 -> 2x2
    // 2x3 -> 3x2
    // 2x4 -> 4x2
    // 2x8 -> 8x2
    // 2x16 -> 16x2
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF);
#endif // N0 > 8

#elif K0 == 3 // K0 == 3
    // This part computes the following transpositions:
    // 3x2 -> 2x3
    // 3x3 -> 3x3
    // 3x4 -> 4x3
    // 3x8 -> 8x3
    // 3x16 -> 16x3
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF);
#endif // N0 > 8

#elif K0 == 4 // K0 == 4
    // This part computes the following transpositions:
    // 4x2 -> 2x4
    // 4x3 -> 3x4
    // 4x4 -> 4x4
    // 4x8 -> 8x4
    // 4x16 -> 16x4
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF);
#endif // N0 > 8

#elif K0 == 8 // K0 == 8
    // This part computes the following transpositions:
    // 8x2 -> 2x8
    // 8x3 -> 3x8
    // 8x4 -> 4x8
    // 8x8 -> 8x8
    // 8x16 -> 16x8
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF);
#endif // N0 > 8

#elif K0 == 16 // K0 == 16

    // This part computes the following transpositions:
    // 16x2 -> 2x16
    // 16x3 -> 3x16
    // 16x4 -> 4x16
    // 16x8 -> 8x16
    // 16x16 -> 16x16
    res0 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s0, a1.s0, a2.s0, a3.s0, a4.s0, a5.s0, a6.s0, a7.s0,
                                          a8.s0, a9.s0, aA.s0, aB.s0, aC.s0, aD.s0, aE.s0, aF.s0);
    res1 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s1, a1.s1, a2.s1, a3.s1, a4.s1, a5.s1, a6.s1, a7.s1,
                                          a8.s1, a9.s1, aA.s1, aB.s1, aC.s1, aD.s1, aE.s1, aF.s1);
#if N0 > 2
    res2 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s2, a1.s2, a2.s2, a3.s2, a4.s2, a5.s2, a6.s2, a7.s2,
                                          a8.s2, a9.s2, aA.s2, aB.s2, aC.s2, aD.s2, aE.s2, aF.s2);
#endif // N0 > 2
#if N0 > 3
    res3 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s3, a1.s3, a2.s3, a3.s3, a4.s3, a5.s3, a6.s3, a7.s3,
                                          a8.s3, a9.s3, aA.s3, aB.s3, aC.s3, aD.s3, aE.s3, aF.s3);
#endif // N0 > 3
#if N0 > 4
    res4 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s4, a1.s4, a2.s4, a3.s4, a4.s4, a5.s4, a6.s4, a7.s4,
                                          a8.s4, a9.s4, aA.s4, aB.s4, aC.s4, aD.s4, aE.s4, aF.s4);
    res5 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s5, a1.s5, a2.s5, a3.s5, a4.s5, a5.s5, a6.s5, a7.s5,
                                          a8.s5, a9.s5, aA.s5, aB.s5, aC.s5, aD.s5, aE.s5, aF.s5);
    res6 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s6, a1.s6, a2.s6, a3.s6, a4.s6, a5.s6, a6.s6, a7.s6,
                                          a8.s6, a9.s6, aA.s6, aB.s6, aC.s6, aD.s6, aE.s6, aF.s6);
    res7 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s7, a1.s7, a2.s7, a3.s7, a4.s7, a5.s7, a6.s7, a7.s7,
                                          a8.s7, a9.s7, aA.s7, aB.s7, aC.s7, aD.s7, aE.s7, aF.s7);
#endif // N0 > 4
#if N0 > 8
    res8 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s8, a1.s8, a2.s8, a3.s8, a4.s8, a5.s8, a6.s8, a7.s8,
                                          a8.s8, a9.s8, aA.s8, aB.s8, aC.s8, aD.s8, aE.s8, aF.s8);
    res9 = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s9, a1.s9, a2.s9, a3.s9, a4.s9, a5.s9, a6.s9, a7.s9,
                                          a8.s9, a9.s9, aA.s9, aB.s9, aC.s9, aD.s9, aE.s9, aF.s9);
    resA = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sA, a1.sA, a2.sA, a3.sA, a4.sA, a5.sA, a6.sA, a7.sA,
                                          a8.sA, a9.sA, aA.sA, aB.sA, aC.sA, aD.sA, aE.sA, aF.sA);
    resB = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sB, a1.sB, a2.sB, a3.sB, a4.sB, a5.sB, a6.sB, a7.sB,
                                          a8.sB, a9.sB, aA.sB, aB.sB, aC.sB, aD.sB, aE.sB, aF.sB);
    resC = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sC, a1.sC, a2.sC, a3.sC, a4.sC, a5.sC, a6.sC, a7.sC,
                                          a8.sC, a9.sC, aA.sC, aB.sC, aC.sC, aD.sC, aE.sC, aF.sC);
    resD = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sD, a1.sD, a2.sD, a3.sD, a4.sD, a5.sD, a6.sD, a7.sD,
                                          a8.sD, a9.sD, aA.sD, aB.sD, aC.sD, aD.sD, aE.sD, aF.sD);
    resE = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sE, a1.sE, a2.sE, a3.sE, a4.sE, a5.sE, a6.sE, a7.sE,
                                          a8.sE, a9.sE, aA.sE, aB.sE, aC.sE, aD.sE, aE.sE, aF.sE);
    resF = (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.sF, a1.sF, a2.sF, a3.sF, a4.sF, a5.sF, a6.sF, a7.sF,
                                          a8.sF, a9.sF, aA.sF, aB.sF, aC.sF, aD.sF, aE.sF, aF.sF);
#endif // N0 > 8

#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // ---------------------------Store the output values ------------------------------
    // N0 rows of K0 elements, consecutive rows OUTPUT_STEP_X elements apart
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
// Token-pasting helper: CONCAT(a, b) expands to the single token a##b
// (e.g. CONCAT(ARM_DOT, 4) -> ARM_DOT4). Arguments adjacent to ## are pasted
// as written, without being macro-expanded first.
#define CONCAT(a, b) a##b

// Multiply-accumulate helpers used by ARM_DOT_K0XN0.
// ARM_DOTn(a, b, c) folds the first n components of the vectors a and b into the
// scalar lvalue c, one component at a time in ascending order: c = fma(a.sK, b.sK, c).
// ARM_DOT8 and ARM_DOT16 split their operands into .lo/.hi halves and recurse, so the
// accumulation order is still component 0 first, component n-1 last.
// NOTE: a and b are used directly in front of the component selector (a.s0, a.lo, ...),
// so callers should pass them parenthesised, as ARM_DOT_K0XN0 does.
// The bodies use GNU statement expressions ({ ... }).
#define ARM_DOT1(a, b, c) \
    ({                    \
        c = fma(a, b, c); \
    })
#define ARM_DOT2(a, b, c)       \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
    })
#define ARM_DOT3(a, b, c)           \
    ({                              \
        ARM_DOT2(a, b, c);          \
        c = fma((a.s2), (b.s2), c); \
    })
#define ARM_DOT4(a, b, c)           \
    ({                              \
        ARM_DOT3(a, b, c);          \
        c = fma((a.s3), (b.s3), c); \
    })
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a.lo), (b.lo), c); \
        ARM_DOT4((a.hi), (b.hi), c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a.lo), (b.lo), c); \
        ARM_DOT8((a.hi), (b.hi), c); \
    })

918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001013 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001018 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001025 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Sheri Zhang1a378102020-04-30 12:59:39 +01001034 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001036 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001037 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001038 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1042 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1044 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Sheri Zhang1a378102020-04-30 12:59:39 +01001058 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
Sheri Zhang1a378102020-04-30 12:59:39 +01001115 // Compute RHS reshaped matrix address
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
Sheri Zhang1a378102020-04-30 12:59:39 +01001161 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
Sheri Zhang1a378102020-04-30 12:59:39 +01001198 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
/* VFMA(a, b, c): multiply-accumulate into c, i.e. c <- fma(a, b, c).
 * Kept as a ({ ... }) statement expression so it composes with the other
 * statement-expression macros that invoke it. */
#define VFMA(a, b, c) ({ (c) = fma((a), (b), (c)); })
1300
/* LD_RHS_VFMA_M0xN0(i, a, c)
 *
 * Loads one row of N0 elements from the reshaped RHS matrix into a local
 * vector b, reading from
 *     rhs_ptr + rhs_offset + i * RHS_STEP_X * sizeof(DATA_TYPE)
 * and then, for every LHS row r in [0, M0), accumulates
 *     c<r> = fma((VEC_DATA_TYPE(DATA_TYPE, N0))(a<r>.s<i>), b, c<r>)
 * i.e. component i of a<r> is broadcast to N0 lanes and multiplied by b.
 *
 * i must be a single hexadecimal digit token (0..9, A..F): it is token-pasted
 * both into a hex literal (0x##i) and into a component selector (.s##i).
 * a and c are name prefixes: a##0, c##0, ... are pasted per row.
 * The expansion refers to rhs_ptr, rhs_offset and RHS_STEP_X from the
 * enclosing scope, and to VFMA defined above.
 * One variant is defined per supported M0 (1..8); any other M0 is an error.
 */
#if M0 == 1
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
    })
#elif M0 == 2 // M0 == 2
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
    })
#elif M0 == 3 // M0 == 3
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
    })
#elif M0 == 4 // M0 == 4
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                                          \
    })
#elif M0 == 5 // M0 == 5
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                                          \
    })
#elif M0 == 6 // M0 == 6
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                                          \
    })
#elif M0 == 7 // M0 == 7
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                                          \
    })
#elif M0 == 8 // M0 == 8
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                                                 \
    ({                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                                               \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                                          \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7));                                                          \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported
1388
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - H0 >= 1
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note Optional compile-time options:
 *       -# ALPHA: if defined, the accumulated result is scaled by ALPHA
 *       -# BETA: if defined, the bias tensor is an additional input; it is scaled by BETA (unless UNIT_BETA is defined) and added to the result
 *       -# BROADCAST_BIAS: if defined (with BETA), a single bias row of N0 elements is added to all M0 rows of the block
 *       -# MATRIX_B_DEPTH: if defined, the RHS batch index is z % MATRIX_B_DEPTH instead of z
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix (M)
 *
 * @param[in]  lhs_ptr                           Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                      Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                      Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                           Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                      Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                      Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                          (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                     (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                       (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                     (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                       (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                      Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                      Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                     (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad               (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad               (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
                                           IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                           IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                           IMAGE_DECLARATION(dst),
                                           uint lhs_stride_z,
                                           uint rhs_stride_z,
#if defined(BETA)
                                           uint bias_stride_z,
#endif //defined(BETA)
                                           uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                           ,
                                           uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                           ,
                                           uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                          )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X (all in elements):
    //  - RHS_OFFSET_X:  distance between the starts of the H0 blocks that share a reshaped row (selected by x % H0)
    //  - RHS_STEP_X:    distance between two consecutive K rows of the same K0xN0 block
    //  - RHS_STEP_LOOP: extra multiplier applied when advancing by K0 rows in the main loop
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose block starts outside the M x N output have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS reshaped matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero15=0;

#if defined(REINTERPRET_INPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);

        // One RHS row per K step: load it and accumulate into all M0 rows of c
        LD_RHS_VFMA_M0xN0(0, a, c);
        LD_RHS_VFMA_M0xN0(1, a, c);
#if K0 > 2
        LD_RHS_VFMA_M0xN0(2, a, c);
#endif // K0 > 2
#if K0 > 3
        LD_RHS_VFMA_M0xN0(3, a, c);
#endif // K0 > 3
#if K0 > 4
        LD_RHS_VFMA_M0xN0(4, a, c);
        LD_RHS_VFMA_M0xN0(5, a, c);
        LD_RHS_VFMA_M0xN0(6, a, c);
        LD_RHS_VFMA_M0xN0(7, a, c);
#endif // K0 > 4
#if K0 > 8
        LD_RHS_VFMA_M0xN0(8, a, c);
        LD_RHS_VFMA_M0xN0(9, a, c);
        LD_RHS_VFMA_M0xN0(A, a, c);
        LD_RHS_VFMA_M0xN0(B, a, c);
        LD_RHS_VFMA_M0xN0(C, a, c);
        LD_RHS_VFMA_M0xN0(D, a, c);
        LD_RHS_VFMA_M0xN0(E, a, c);
        LD_RHS_VFMA_M0xN0(F, a, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix.
        // Each a<r> is a 2-element vector initialised from one scalar so that
        // LD_RHS_VFMA_M0xN0(0, ...) can read it through the component selector .s0
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        LD_RHS_VFMA_M0xN0(0, a, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Do not leak the kernel-local helper macros into the rest of the file
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001677#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001678
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001679#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001680
/* ARM_DOT_K0(a, b, c)
 *
 * Accumulates the K0-element dot product of the vectors a and b into the
 * scalar c, one component at a time in order .s0, .s1, ..., .s(K0-1).
 *
 * Each component is handled by ARM_DOT_K0_STEP(x, y, acc):
 *  - with MIXED_PRECISION: acc += x * y  (the product is formed in the
 *    operands' type and then added to the accumulator);
 *  - otherwise:            acc = fma(x, y, acc).
 *
 * Only K0 = 2, 3, 4, 8, 16 are supported; any other value is a compile error.
 */
#if defined(MIXED_PRECISION)
#define ARM_DOT_K0_STEP(x, y, acc) (acc) += (x) * (y)
#else // defined(MIXED_PRECISION)
#define ARM_DOT_K0_STEP(x, y, acc) (acc) = fma((x), (y), (acc))
#endif // defined(MIXED_PRECISION)

#if K0 == 2
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
        ARM_DOT_K0_STEP(a.s4, b.s4, c); \
        ARM_DOT_K0_STEP(a.s5, b.s5, c); \
        ARM_DOT_K0_STEP(a.s6, b.s6, c); \
        ARM_DOT_K0_STEP(a.s7, b.s7, c); \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
        ARM_DOT_K0_STEP(a.s4, b.s4, c); \
        ARM_DOT_K0_STEP(a.s5, b.s5, c); \
        ARM_DOT_K0_STEP(a.s6, b.s6, c); \
        ARM_DOT_K0_STEP(a.s7, b.s7, c); \
        ARM_DOT_K0_STEP(a.s8, b.s8, c); \
        ARM_DOT_K0_STEP(a.s9, b.s9, c); \
        ARM_DOT_K0_STEP(a.sA, b.sA, c); \
        ARM_DOT_K0_STEP(a.sB, b.sB, c); \
        ARM_DOT_K0_STEP(a.sC, b.sC, c); \
        ARM_DOT_K0_STEP(a.sD, b.sD, c); \
        ARM_DOT_K0_STEP(a.sE, b.sE, c); \
        ARM_DOT_K0_STEP(a.sF, b.sF, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001796
/* ARM_DOT_K0XN0(a, b, c)
 *
 * For each of the N0 output columns n, accumulates the K0-element dot product
 * of the vector a with the vector b<n> into component c.s<n>, using
 * ARM_DOT_K0 (defined above). b is a name prefix: b##0, b##1, ... are pasted
 * from the raw argument, so b must be the bare prefix of the RHS row vectors.
 * Only N0 = 2, 3, 4, 8, 16 are supported; any other value is a compile error.
 */
#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1853
1854/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1855 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1856 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1857 *
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001858 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1859 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
1860 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001861 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001862 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
1863 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1864 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1865 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1868 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01001869 * - M0 = 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001870 * - N0 = 2, 3, 4, 8, 16
1871 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001872 * - V0 >= 1
1873 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001874 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001875 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001876 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001877 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001878 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1879 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1880 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1881 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1882 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001883 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1884 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1885 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1886 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1887 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1888 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1889 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1890 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1891 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1892 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1893 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1894 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1895 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1896 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1897 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1898 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1899 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1900 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1901 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1902 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1903 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1904 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1905 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1906 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1907 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
1908 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1909 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1910 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1911 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1912 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001913 */
__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
                                            IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                            IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                            IMAGE_DECLARATION(dst),
                                            uint k,
                                            uint lhs_stride_z,
                                            uint rhs_stride_z,
#if defined(BETA)
                                            uint bias_stride_z,
#endif //defined(BETA)
                                            uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                            ,
                                            uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                           )
{
    // Number of elements in one M0xK0 block of the reshaped LHS matrix
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS offset and step X (in elements). LHS_STEP_LOOP scales the address advance applied after
    // each K0 step of the main loop.
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (K0)
#define LHS_STEP_X ((K0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (K0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // Number of elements in one K0xN0 block of the reshaped RHS matrix
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X (in elements). RHS_STEP_LOOP scales the address advance applied after
    // each K0 step of the main loop.
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (K0)
#define RHS_STEP_X ((K0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (K0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose output block starts at or beyond N (columns) or M (rows) have nothing to compute
    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address: (y % V0) selects the block inside a reshaped row, (y / V0) the row
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
                               (get_global_id(2) * lhs_stride_z);

    // Compute RHS matrix address: (x % H0) selects the block inside a reshaped row, (x / H0) the row
    __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_addr += get_global_id(2) * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize the accumulators c0..c(M0-1), each an N0-wide vector of DATA_TYPE_ACCUMULATOR
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);

    // All-zero Z offsets for the block loads below
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

    // Unsigned counter: k is a uint, so this avoids a signed/unsigned comparison
    for(uint i = 0; i < k; i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix: M0 rows (a0..a(M0-1)) of K0 elements
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);

        // Load values from RHS matrix: N0 rows (b0..b(N0-1)) of K0 elements
        LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);

        // Accumulate each LHS row against the RHS block (ARM_DOT_K0XN0 is defined elsewhere in this file)
        ARM_DOT_K0XN0(a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(a7, b, c7);
#endif // M0 > 7

        // Advance to the next K0 step of both reshaped matrices
        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
        rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product.
    // The accumulator type is passed so that, with MIXED_PRECISION, ALPHA is not first rounded to
    // DATA_TYPE before scaling the wider accumulators (assumes SCALE_BLOCK casts ALPHA to the given
    // type, as in gemm_helpers.h). Without MIXED_PRECISION the two types are expected to coincide.
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE_ACCUMULATOR, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK_BROADCAST(M0, c, bias0);
#endif // defined(MIXED_PRECISION)

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK(M0, c, bias_hp);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK(M0, c, bias);
#endif // defined(MIXED_PRECISION)

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
#if defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
#else  // defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(MIXED_PRECISION)
#endif // defined(ACTIVATION_TYPE)

    // Store output block
#if defined(MIXED_PRECISION)
    CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#else  // defined(MIXED_PRECISION)
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#endif // defined(MIXED_PRECISION)

    // Undefine every kernel-local macro so it cannot leak into (or clash with) later kernels
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef LHS_STEP_LOOP
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
giuros01b3204e72019-04-01 13:50:22 +01002118
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002119#if defined(LHS_TRANSPOSE)
2120
// Shorthand for an OpenCL vector type of SIZE elements of TYPE
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)

// ARM_VFMA(N0, a, b, c): accumulate the product a * b into c (N0-wide vectors).
// On Midgard the update is a separate multiply and add; elsewhere it uses the fma() built-in.
// With MIXED_PRECISION both operands are first converted to N0-wide DATA_TYPE_ACCUMULATOR vectors.
// Every variant expands to a complete statement, including its own trailing semicolon.
#if(GPU_ARCH == GPU_ARCH_MIDGARD)

#if defined(MIXED_PRECISION)
#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
#else // defined(MIXED_PRECISION)
#define ARM_VFMA(N0, a, b, c) c += (a) * (b);
#endif // defined(MIXED_PRECISION)

#else // GPU_ARCH == GPU_ARCH_MIDGARD

#if defined(MIXED_PRECISION)
#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
#else // defined(MIXED_PRECISION)
#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
#endif // defined(MIXED_PRECISION)

#endif // GPU_ARCH == GPU_ARCH_MIDGARD
2140
// Column-vector (transposed) by row-vector (not transposed) updates for a single k (K0 = 1).
// ARM_VVM_T_NT_<M0>xN0x1 casts lane m of the column vector col to an N0-wide vector of TYPE and
// accumulates its product with the row vector row into output row ACC##m through ARM_VFMA.
// For M0 == 1 col is a scalar, so it is cast directly without lane selection.
// The updates are issued in increasing lane order.
#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, col, row, ACC)            \
    ({                                                          \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col), row, (ACC##0));    \
    })
#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, col, row, ACC)            \
    ({                                                          \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s0), row, (ACC##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s1), row, (ACC##1)); \
    })
#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, col, row, ACC)            \
    ({                                                          \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s0), row, (ACC##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s1), row, (ACC##1)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s2), row, (ACC##2)); \
    })
#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, col, row, ACC)            \
    ({                                                          \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s0), row, (ACC##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s1), row, (ACC##1)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s2), row, (ACC##2)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s3), row, (ACC##3)); \
    })
#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, col, row, ACC)            \
    ({                                                          \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s0), row, (ACC##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s1), row, (ACC##1)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s2), row, (ACC##2)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s3), row, (ACC##3)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s4), row, (ACC##4)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s5), row, (ACC##5)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s6), row, (ACC##6)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(col.s7), row, (ACC##7)); \
    })

// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
// col is the column-vector (transposed), row is the row-vector (not transposed) and ACC names the
// output matrix whose rows are ACC0, ACC1, ...
// M0 is token-pasted, so it must already be a literal 1, 2, 3, 4 or 8 when this macro is expanded.
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, col, row, ACC) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, col, row, ACC)
2176
// Matrix (transposed) by matrix (not transposed) products for a fixed K0.
// ARM_MM_T_NT_M0xN0x<K0> issues K0 column-vector by row-vector updates in increasing k order. Each
// update pairs LHS##k with RHS##k and accumulates into the output matrix DST.
// The suffixes k follow the 0-9, A-F convention, so the macro parameters must NOT be named A-F:
// a parameter called A or C would make LHS##A / LHS##C paste an argument instead of the letter.
#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, DST)              \
    ({                                                               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##0), (RHS##0), DST); \
    })
#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, DST)              \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, DST);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##1), (RHS##1), DST); \
    })
#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, DST)              \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, DST);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##2), (RHS##2), DST); \
    })
#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, DST)              \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, DST);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##3), (RHS##3), DST); \
    })
#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, DST)              \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, DST);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##4), (RHS##4), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##5), (RHS##5), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##6), (RHS##6), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##7), (RHS##7), DST); \
    })
// The 8..F terms call the vector-by-matrix macro directly with the already-suffixed operands.
// Calling ARM_MM_T_NT_M0xN0x1 here would paste a second suffix onto "(LHS##8)", which is invalid.
#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, LHS, RHS, DST)             \
    ({                                                               \
        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, DST);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##8), (RHS##8), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##9), (RHS##9), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##A), (RHS##A), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##B), (RHS##B), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##C), (RHS##C), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##D), (RHS##D), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##E), (RHS##E), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##F), (RHS##F), DST); \
    })

// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
// The dimensions for this matrix multiplications are defined through M0, N0 and K0
// The dimensions supported are:
// M0: 1, 2, 3, 4, 8
// N0: 1, 2, 3, 4, 8, 16
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// LHS, RHS and DST name matrices whose rows/columns are LHS0, LHS1, ... (suffixes 0-9, A-F)
#define ARM_MM_T_NT(M0, N0, K0, TYPE, LHS, RHS, DST) \
    CONCAT(ARM_MM_T_NT_M0xN0x, K0)                   \
    (M0, N0, TYPE, LHS, RHS, DST)
2228
2229/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2230 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2231 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2232 *
2233 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2234 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2235 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
2236 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2237 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2238 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2241 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2242 * - M0 = 2, 3, 4, 8
2243 * - N0 = 2, 3, 4, 8, 16
2244 * - K0 = 2, 3, 4, 8, 16
2245 * - V0 >= 1
2246 * - H0 >= 1
2247 *
2248 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2249 * The activation function is performed after the bias addition
2250 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2251 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2252 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2253 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2254 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2255 *
2256 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2257 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2258 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2259 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2260 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2261 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2262 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2263 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2264 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2265 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2266 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2267 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2268 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2269 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2270 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2271 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2272 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2273 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2274 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2275 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2276 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2277 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2278 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2280 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
2281 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2282 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2283 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2284 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2285 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2286 */
2287__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
2288 IMAGE_DECLARATION(rhs),
2289#if defined(BETA)
2290 IMAGE_DECLARATION(bias),
2291#endif // defined(BETA)
2292 IMAGE_DECLARATION(dst),
2293 uint k,
2294 uint lhs_stride_z,
2295 uint rhs_stride_z,
2296#if defined(BETA)
2297 uint bias_stride_z,
2298#endif //defined(BETA)
2299 uint dst_stride_z
2300#if defined(REINTERPRET_OUTPUT_AS_3D)
2301 ,
2302 uint dst_cross_plane_pad
2303#endif // REINTERPRET_OUTPUT_AS_3D
2304 )
2305{
2306 // Block size
2307#define LHS_BLOCK_SIZE ((K0) * (M0))
2308
2309#if defined(LHS_INTERLEAVE)
2310#define LHS_OFFSET_X (M0)
2311#define LHS_STEP_X ((M0) * (V0))
2312#define LHS_STEP_LOOP (1)
2313#else // defined(INTERLEAVE)
2314#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2315#define LHS_STEP_X (M0)
2316#define LHS_STEP_LOOP (V0)
2317#endif // defined(INTERLEAVE)
2318
2319 // Block size
2320#define RHS_BLOCK_SIZE ((K0) * (N0))
2321
2322 // RHS offset and step X
2323#if defined(RHS_INTERLEAVE)
2324#define RHS_OFFSET_X (N0)
2325#define RHS_STEP_X ((N0) * (H0))
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002326#else // defined(RHS_INTERLEAVE)
2327#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2328#define RHS_STEP_X (N0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002329#endif // defined(RHS_INTERLEAVE)
2330
2331 const uint x = get_global_id(0);
2332 const uint y = get_global_id(1);
2333 const uint z = get_global_id(2);
2334
2335#if defined(DUMMY_WORK_ITEMS)
2336 if((x * N0 >= N) || (y * M0 >= M))
2337 {
2338 return;
2339 }
2340#endif // defined(DUMMY_WORK_ITEMS)
2341
2342 // Compute LHS matrix address
2343 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2344
2345 // Compute RHS matrix address
2346 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2347
2348#if defined(MATRIX_B_DEPTH)
2349 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2350 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2351#else // defined(MATRIX_B_DEPTH)
2352 rhs_addr += z * rhs_stride_z;
2353#endif // defined(MATRIX_B_DEPTH)
2354
2355 // Initialize the accumulators
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002356 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002357
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002358 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2359
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002360 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2361 __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2362
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002363 for(int i = 0; i < k; i += K0)
2364 {
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002365 VEC_DATA_TYPE(DATA_TYPE, M0)
2366 a0 = VLOAD(M0)(0, lhs);
2367 VEC_DATA_TYPE(DATA_TYPE, N0)
2368 b0 = VLOAD(N0)(0, rhs);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002369
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002370 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002371
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002372 lhs += LHS_STEP_X;
2373 rhs += RHS_STEP_X;
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002374
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002375#if K0 > 1
2376 a0 = VLOAD(M0)(0, lhs);
2377 b0 = VLOAD(N0)(0, rhs);
2378
2379 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2380
2381 lhs += LHS_STEP_X;
2382 rhs += RHS_STEP_X;
2383#endif // K0 > 1
2384
2385#if K0 > 2
2386 a0 = VLOAD(M0)(0, lhs);
2387 b0 = VLOAD(N0)(0, rhs);
2388
2389 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2390
2391 lhs += LHS_STEP_X;
2392 rhs += RHS_STEP_X;
2393#endif // K0 > 2
2394
2395#if K0 > 3
2396 a0 = VLOAD(M0)(0, lhs);
2397 b0 = VLOAD(N0)(0, rhs);
2398
2399 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2400
2401 lhs += LHS_STEP_X;
2402 rhs += RHS_STEP_X;
2403#endif // K0 > 3
2404
2405#if K0 > 4
2406 a0 = VLOAD(M0)(0, lhs);
2407 b0 = VLOAD(N0)(0, rhs);
2408
2409 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2410
2411 lhs += LHS_STEP_X;
2412 rhs += RHS_STEP_X;
2413
2414 a0 = VLOAD(M0)(0, lhs);
2415 b0 = VLOAD(N0)(0, rhs);
2416
2417 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2418
2419 lhs += LHS_STEP_X;
2420 rhs += RHS_STEP_X;
2421
2422 a0 = VLOAD(M0)(0, lhs);
2423 b0 = VLOAD(N0)(0, rhs);
2424
2425 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2426
2427 lhs += LHS_STEP_X;
2428 rhs += RHS_STEP_X;
2429
2430 a0 = VLOAD(M0)(0, lhs);
2431 b0 = VLOAD(N0)(0, rhs);
2432
2433 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2434
2435 lhs += LHS_STEP_X;
2436 rhs += RHS_STEP_X;
2437#endif // K0 > 4
2438
2439#if K0 > 8
2440 a0 = VLOAD(M0)(0, lhs);
2441 b0 = VLOAD(N0)(0, rhs);
2442
2443 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2444
2445 lhs += LHS_STEP_X;
2446 rhs += RHS_STEP_X;
2447
2448 a0 = VLOAD(M0)(0, lhs);
2449 b0 = VLOAD(N0)(0, rhs);
2450
2451 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2452
2453 lhs += LHS_STEP_X;
2454 rhs += RHS_STEP_X;
2455
2456 a0 = VLOAD(M0)(0, lhs);
2457 b0 = VLOAD(N0)(0, rhs);
2458
2459 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2460
2461 lhs += LHS_STEP_X;
2462 rhs += RHS_STEP_X;
2463
2464 a0 = VLOAD(M0)(0, lhs);
2465 b0 = VLOAD(N0)(0, rhs);
2466
2467 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2468
2469 lhs += LHS_STEP_X;
2470 rhs += RHS_STEP_X;
2471
2472 a0 = VLOAD(M0)(0, lhs);
2473 b0 = VLOAD(N0)(0, rhs);
2474
2475 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2476
2477 lhs += LHS_STEP_X;
2478 rhs += RHS_STEP_X;
2479
2480 a0 = VLOAD(M0)(0, lhs);
2481 b0 = VLOAD(N0)(0, rhs);
2482
2483 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2484
2485 lhs += LHS_STEP_X;
2486 rhs += RHS_STEP_X;
2487
2488 a0 = VLOAD(M0)(0, lhs);
2489 b0 = VLOAD(N0)(0, rhs);
2490
2491 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2492
2493 lhs += LHS_STEP_X;
2494 rhs += RHS_STEP_X;
2495
2496 a0 = VLOAD(M0)(0, lhs);
2497 b0 = VLOAD(N0)(0, rhs);
2498
2499 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2500
2501 lhs += LHS_STEP_X;
2502 rhs += RHS_STEP_X;
2503#endif // K0 > 8
2504
2505#ifndef LHS_INTERLEAVE
2506 lhs += (M0 * K0 * (V0 - 1));
2507#endif // LHS_INTERLEAVE
2508
2509#ifndef RHS_INTERLEAVE
2510 rhs += (N0 * K0 * (H0 - 1));
2511#endif // RHS_INTERLEAVE
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002512 }
2513
2514 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2515
2516 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2517
2518#if defined(REINTERPRET_OUTPUT_AS_3D)
2519
2520 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2521 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2522 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2523 // multiply dst_stride_z by DEPTH_GEMM3D
2524 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2525
2526#else // defined(REINTERPRET_OUTPUT_AS_3D)
2527
2528 // Add offset for batched GEMM
2529 dst_addr += z * dst_stride_z;
2530
2531#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2532
2533 // Multiply by the weight of matrix-matrix product and store the result
2534#if defined(ALPHA)
2535 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2536#endif // defined(ALPHA)
2537
2538 // Add beta*bias
2539#if defined(BETA)
2540#if defined(BROADCAST_BIAS)
2541 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
2542
2543 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2544
2545#ifndef UNIT_BETA
2546 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2547#endif // UNIT_BIAS
2548
2549 // c = c + bias[broadcasted]
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002550#if defined(MIXED_PRECISION)
2551 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2552 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2553#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002554 ADD_BLOCK_BROADCAST(M0, c, bias0);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002555#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002556
2557#else // defined(BROADCAST_BIAS)
2558 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
2559
2560 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2561
2562#ifndef UNIT_BETA
2563 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2564#endif // UNIT_BIAS
2565
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002566#if defined(MIXED_PRECISION)
2567 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2568 ADD_BLOCK(M0, c, bias_hp);
2569#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002570 ADD_BLOCK(M0, c, bias);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002571#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002572
2573#endif // defined(BROADCAST_BIAS)
2574#endif // defined(BETA)
2575
2576#if defined(ACTIVATION_TYPE)
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002577#if defined(MIXED_PRECISION)
2578 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
2579#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002580 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002581#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002582#endif // defined(ACTIVATION_TYPE)
2583
2584 // Store output block
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002585#if defined(MIXED_PRECISION)
2586 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2587#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002588 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002589#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002590
2591#undef LHS_BLOCK_SIZE
2592#undef LHS_OFFSET_X
2593#undef LHS_STEP_X
2594#undef RHS_BLOCK_SIZE
2595#undef RHS_OFFSET_X
2596#undef RHS_STEP_X
2597}
2598
2599#endif // defined(LHS_TRANSPOSE)
2600
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002601#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
2602
giuros01b3204e72019-04-01 13:50:22 +01002603#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2604
/** Vector fused multiply-add: acc = x * y + acc, evaluated with the fma() built-in.
 *
 * @param x   First multiplicand
 * @param y   Second multiplicand
 * @param acc Accumulator; must be an lvalue, it is overwritten with the result
 *
 * The statement-expression wrapper makes the macro usable as a statement and also yields the updated accumulator.
 */
#define VFMA(x, y, acc)               \
    ({                                \
        (acc) = fma((x), (y), (acc)); \
    })
2609
/** RHS_VFMA_M0xN0(i, a, b, c): rank-1 update of an M0 x N0 accumulator block.
 *
 * For every row r in [0, M0): c<r> = fma((VEC_DATA_TYPE(DATA_TYPE, N0))(a<r>.s<i>), b, c<r>)
 *
 * @param i Component index of the LHS rows, token-pasted as ".s<i>" (a single token such as 0-9 or A-F)
 * @param a Name prefix of the LHS row variables a0 .. a(M0-1)
 * @param b Vector of N0 elements that multiplies every broadcast LHS component
 * @param c Name prefix of the accumulator variables c0 .. c(M0-1)
 *
 * Only M0 = 1..8 is supported; any other value stops compilation.
 */

// One accumulator row: c_row = fma(broadcast(a_row.lane), b, c_row)
#define RHS_VFMA_M0xN0_ROW(lane, a_row, b, c_row) \
    VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a_row).lane), b, (c_row))

#if M0 == 1
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
    })
#elif M0 == 2
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
    })
#elif M0 == 3
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
    })
#elif M0 == 4
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
        RHS_VFMA_M0xN0_ROW(s##i, a##3, b, c##3); \
    })
#elif M0 == 5
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
        RHS_VFMA_M0xN0_ROW(s##i, a##3, b, c##3); \
        RHS_VFMA_M0xN0_ROW(s##i, a##4, b, c##4); \
    })
#elif M0 == 6
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
        RHS_VFMA_M0xN0_ROW(s##i, a##3, b, c##3); \
        RHS_VFMA_M0xN0_ROW(s##i, a##4, b, c##4); \
        RHS_VFMA_M0xN0_ROW(s##i, a##5, b, c##5); \
    })
#elif M0 == 7
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
        RHS_VFMA_M0xN0_ROW(s##i, a##3, b, c##3); \
        RHS_VFMA_M0xN0_ROW(s##i, a##4, b, c##4); \
        RHS_VFMA_M0xN0_ROW(s##i, a##5, b, c##5); \
        RHS_VFMA_M0xN0_ROW(s##i, a##6, b, c##6); \
    })
#elif M0 == 8
#define RHS_VFMA_M0xN0(i, a, b, c)               \
    ({                                           \
        RHS_VFMA_M0xN0_ROW(s##i, a##0, b, c##0); \
        RHS_VFMA_M0xN0_ROW(s##i, a##1, b, c##1); \
        RHS_VFMA_M0xN0_ROW(s##i, a##2, b, c##2); \
        RHS_VFMA_M0xN0_ROW(s##i, a##3, b, c##3); \
        RHS_VFMA_M0xN0_ROW(s##i, a##4, b, c##4); \
        RHS_VFMA_M0xN0_ROW(s##i, a##5, b, c##5); \
        RHS_VFMA_M0xN0_ROW(s##i, a##6, b, c##6); \
        RHS_VFMA_M0xN0_ROW(s##i, a##7, b, c##7); \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported
2681
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS matrix is NOT reshaped
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *
 * @note The optional weight of the matrix-matrix product can be passed at compile time using -DALPHA; the accumulated result is scaled by it.
 * @note If -DBETA is passed, the bias tensor is added to the result after being scaled by BETA (the scaling is skipped if -DUNIT_BETA is also passed).
 *       With -DBROADCAST_BIAS a single bias row (1 x N) is loaded and broadcast across all the rows of the output block.
 * @note In case the RHS matrix has 3 dimensions and the LHS matrix more than 3, the number of channels of the RHS matrix can be passed at compile time
 *       using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16): the RHS matrix is then addressed with (z % MATRIX_B_DEPTH) instead of z.
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the LHS matrix (M)
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
                             IMAGE_DECLARATION(rhs),
#if defined(BETA)
                             IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                             IMAGE_DECLARATION(dst),
                             uint lhs_stride_z,
                             uint rhs_stride_z,
#if defined(BETA)
                             uint bias_stride_z,
#endif //defined(BETA)
                             uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                             ,
                             uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                             ,
                             uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                            )
{
    // x selects a block of N0 output columns, y a block of M0 output rows, z the batch.
    // These cached coordinates are used for every address below (LHS, RHS, bias and dst).
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose block starts outside the M x N output do nothing
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address (first of the M0 rows processed by this work-item)
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address (first of the N0 columns processed by this work-item)
    uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Extra per-row byte offsets for the LHS rows; they stay 0 unless REINTERPRET_INPUT_AS_3D is defined
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
    // All-zero per-row offsets for the RHS and bias loads; 16 entries cover the largest supported K0
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);

        // One rank-1 update per loaded RHS row: c += a[:, k] * b[k, :]
        RHS_VFMA_M0xN0(0, a, b0, c);
        RHS_VFMA_M0xN0(1, a, b1, c);
#if K0 > 2
        RHS_VFMA_M0xN0(2, a, b2, c);
#endif // K0 > 2
#if K0 > 3
        RHS_VFMA_M0xN0(3, a, b3, c);
#endif // K0 > 3
#if K0 > 4
        RHS_VFMA_M0xN0(4, a, b4, c);
        RHS_VFMA_M0xN0(5, a, b5, c);
        RHS_VFMA_M0xN0(6, a, b6, c);
        RHS_VFMA_M0xN0(7, a, b7, c);
#endif // K0 > 4
#if K0 > 8
        RHS_VFMA_M0xN0(8, a, b8, c);
        RHS_VFMA_M0xN0(9, a, b9, c);
        RHS_VFMA_M0xN0(A, a, bA, c);
        RHS_VFMA_M0xN0(B, a, bB, c);
        RHS_VFMA_M0xN0(C, a, bC, c);
        RHS_VFMA_M0xN0(D, a, bD, c);
        RHS_VFMA_M0xN0(E, a, bE, c);
        RHS_VFMA_M0xN0(F, a, bF, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * rhs_stride_y;
    }

    // Left-over accumulations when K is not a multiple of K0: one LHS column / one RHS row per iteration
    for(; i < K; ++i)
    {
        // Load one LHS value per row. The scalar is splatted into a 2-element vector only so that
        // RHS_VFMA_M0xN0 can select component .s0 (OpenCL C does not allow component selection on scalars)
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
#endif // M0 > 7

        VEC_DATA_TYPE(DATA_TYPE, N0)
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
        RHS_VFMA_M0xN0(0, a, b, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += rhs_stride_y;
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    // A single bias row: only the column offset (x) is applied
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    // Full M x N bias: same block coordinates as the destination
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
}
2965#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
2966
Gian Marco36a0a462018-01-12 10:21:40 +00002967#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002968/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00002969 *
Gian Marco19835e52018-01-30 13:35:54 +00002970 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002971 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
2972 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
2973 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2974 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002975 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01002976 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2977 * The activation function is performed after the bias addition
2978 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002979 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2980 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2981 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *       (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
2983 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01002984 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2985 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem (in bytes)
2995 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
2997 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2998 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2999 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3000 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3001 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003002 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003003 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003004 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003005 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003006 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003007 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003008 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3009 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003010 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003011 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003012 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003013 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003014__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
3015 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003016#if defined(BETA)
3017 IMAGE_DECLARATION(src2),
3018#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003019 IMAGE_DECLARATION(dst),
3020 uint src0_stride_z,
3021 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003022#if defined(BETA)
3023 uint src2_stride_z,
3024#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003025 uint dst_stride_z
3026#if defined(REINTERPRET_OUTPUT_AS_3D)
3027 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003028 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003029#endif // REINTERPRET_OUTPUT_AS_3D
3030 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003031{
Gian Marco36a0a462018-01-12 10:21:40 +00003032 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3033 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003034 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003035
Gian Marco36a0a462018-01-12 10:21:40 +00003036 // Offset
3037 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3038 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003039
Gian Marco36a0a462018-01-12 10:21:40 +00003040 // src_addr_a = address of matrix A
3041 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003042 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3043 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3044
3045#if defined(MATRIX_B_DEPTH)
3046 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3047 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3048#else // defined(MATRIX_B_DEPTH)
3049 src1_addr_in_bytes += z * src1_stride_z;
3050#endif // defined(MATRIX_B_DEPTH)
3051
3052 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
3053 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003054
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003055 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00003056 __global float *src_end_addr_b = src_addr_b + COLS_B;
3057
3058 src_addr_a += offset_row_a;
3059 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003060
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003061 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003062 float4 c0 = 0.0f;
3063 float4 c1 = 0.0f;
3064 float4 c2 = 0.0f;
3065 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003066
Gian Marco36a0a462018-01-12 10:21:40 +00003067 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003068 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003069 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003070 float4 a0 = vload4(0, src_addr_a);
3071 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003072
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003073 c0 += (float4)a0.s0 * b0;
3074 c1 += (float4)a0.s1 * b0;
3075 c2 += (float4)a0.s2 * b0;
3076 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003077
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003078 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003079 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
3080 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003081
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003082 c0 += (float4)a0.s0 * b0;
3083 c1 += (float4)a0.s1 * b0;
3084 c2 += (float4)a0.s2 * b0;
3085 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003086 }
3087
Gian Marco36a0a462018-01-12 10:21:40 +00003088 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003089 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003090 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003091 float4 a0 = vload4(0, src_addr_a);
3092 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003093
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003094 c0 += (float4)a0.s0 * b0;
3095 c1 += (float4)a0.s1 * b0;
3096 c2 += (float4)a0.s2 * b0;
3097 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003098 }
3099
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003100 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003101 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3102
Gian Marcoae2af742018-02-15 12:35:44 +00003103 // Compute dst address
3104 __global uchar *dst_addr = offset(&dst, 0, 0);
3105
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003106 uint4 zout = 0;
3107
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003108#if defined(REINTERPRET_OUTPUT_AS_3D)
3109 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003110 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003111 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003112 // | |
3113 // | plane0 |
3114 // | |
3115 // |__________________|
3116 // |******************|
3117 // | cross_plane_pad |
3118 // |******************|
3119 // | |
3120 // | plane1 |
3121 // | |
3122 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003123
3124 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003125 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3126 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003127
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003128 // Add offset due to the cross plane paddings
3129 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003130
3131 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3132 // multiply dst_stride_z by DEPTH_GEMM3D
3133 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003134#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003135 // Add offset for batched GEMM
3136 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003137#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3138
3139 // Multiply by the weight of matrix-matrix product and store the result
3140#if defined(ALPHA)
3141 SCALE_BLOCK(4, float, c, ALPHA);
3142#endif // defined(ALPHA)
3143
3144 // Add beta*bias
3145#if defined(BETA)
3146 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3147
3148#if defined(BROADCAST_BIAS)
3149 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
3150
3151 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3152
3153#ifndef UNIT_BETA
3154 SCALE_BLOCK(1, float, bias, BETA);
3155#endif // UNIT_BIAS
3156
3157 // c = c + bias[broadcasted]
3158 ADD_BLOCK_BROADCAST(4, c, bias0);
3159
3160#else // defined(BROADCAST_BIAS)
3161 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3162 2) * src2_stride_z;
3163
3164 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3165
3166#ifndef UNIT_BETA
3167 SCALE_BLOCK(4, float, bias, BETA);
3168#endif // UNIT_BIAS
3169
3170 // c = c + bias
3171 ADD_BLOCK(4, c, bias);
3172
3173#endif // defined(BROADCAST_BIAS)
3174#endif // defined(BETA)
3175
3176#if defined(ACTIVATION_TYPE)
3177 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
3178#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003179
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003180 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003181 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3182 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3183 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3184 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003185}
3186
/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * Each work-item computes a 4x4 block of the destination matrix, accumulated in four float4 registers (c0..c3).
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) is only added if -DBETA is passed at compile time. The bias is scaled by BETA unless -DUNIT_BETA is also passed.
 *       If -DBROADCAST_BIAS is passed, a single bias row is loaded and added to every row of the 4x4 output block.
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings of each output plane, in number of rows (the byte offset applied is cross_plane_pad * dst_stride_y).
 *                                                Only if defined REINTERPRET_OUTPUT_AS_3D
 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset of this work-item inside the interleaved (A) and transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

// Number of float4 steps along the reshaped K dimension processed by this work-item
#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))

    // Every step loads one float4 from A and one float4 from B and performs a rank-1 update
    // of the 4x4 accumulator block: c_r += a0.s_r * b0. fma() on float4 operands is applied
    // component-wise, so each row update is identical to four scalar fma() calls.
    int i = 0;

    // Main loop, manually unrolled by 4
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((float4)a0.s0, b0, c0);
        c1 = fma((float4)a0.s1, b0, c1);
        c2 = fma((float4)a0.s2, b0, c2);
        c3 = fma((float4)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((float4)a0.s0, b0, c0);
        c1 = fma((float4)a0.s1, b0, c1);
        c2 = fma((float4)a0.s2, b0, c2);
        c3 = fma((float4)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((float4)a0.s0, b0, c0);
        c1 = fma((float4)a0.s1, b0, c1);
        c2 = fma((float4)a0.s2, b0, c2);
        c3 = fma((float4)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((float4)a0.s0, b0, c0);
        c1 = fma((float4)a0.s1, b0, c1);
        c2 = fma((float4)a0.s2, b0, c2);
        c3 = fma((float4)a0.s3, b0, c3);
    }

    // Left-over steps
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((float4)a0.s0, b0, c0);
        c1 = fma((float4)a0.s1, b0, c1);
        c2 = fma((float4)a0.s2, b0, c2);
        c3 = fma((float4)a0.s3, b0, c3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    // Vector first, then a vector of the same type: a spec-conforming min() overload
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block
    vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
}

// Undefine local defines
#undef COLS_MTX_B
3517
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01003518#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003519/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003520 *
Gian Marco19835e52018-01-30 13:35:54 +00003521 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003522 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3523 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3524 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3525 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003526 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003527 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3528 * The activation function is performed after the bias addition
3529 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003530 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3531 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3532 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3533 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3534 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003535 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3536 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3537 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3538 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3539 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3540 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003541 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003542 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3543 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3544 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3545 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3546 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003547 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3548 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3549 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3550 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3551 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3552 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003553 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003554 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003555 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003556 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003557 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003558 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003559 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3560 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003561 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003562 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003563 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003564 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003565__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
3566 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003567#if defined(BETA)
3568 IMAGE_DECLARATION(src2),
3569#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00003570 IMAGE_DECLARATION(dst),
3571 uint src0_stride_z,
3572 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003573#if defined(BETA)
3574 uint src2_stride_z,
3575#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003576 uint dst_stride_z
3577#if defined(REINTERPRET_OUTPUT_AS_3D)
3578 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003579 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003580#endif // REINTERPRET_OUTPUT_AS_3D
3581 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003582{
Gian Marco36a0a462018-01-12 10:21:40 +00003583 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3584 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003585 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003586
Gian Marco36a0a462018-01-12 10:21:40 +00003587 // Offset
3588 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3589 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003590
Gian Marco36a0a462018-01-12 10:21:40 +00003591 // src_addr_a = address of matrix A
3592 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003593 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3594 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3595
3596#if defined(MATRIX_B_DEPTH)
3597 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3598 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3599#else // defined(MATRIX_B_DEPTH)
3600 src1_addr_in_bytes += z * src1_stride_z;
3601#endif // defined(MATRIX_B_DEPTH)
3602
3603 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
3604 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003605
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003606 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00003607 __global half *src_end_addr_b = src_addr_b + COLS_B;
3608
3609 src_addr_a += offset_row_a;
3610 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003611
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003612 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003613 half8 c0 = 0.0f;
3614 half8 c1 = 0.0f;
3615 half8 c2 = 0.0f;
3616 half8 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003617
Gian Marco36a0a462018-01-12 10:21:40 +00003618 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003619 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003620 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003621 half4 a0 = vload4(0, src_addr_a);
3622 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003623
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003624 c0 += (half8)a0.s0 * b0;
3625 c1 += (half8)a0.s1 * b0;
3626 c2 += (half8)a0.s2 * b0;
3627 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003628
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003629 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003630 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
3631 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003632
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003633 c0 += (half8)a0.s0 * b0;
3634 c1 += (half8)a0.s1 * b0;
3635 c2 += (half8)a0.s2 * b0;
3636 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003637 }
3638
Gian Marco36a0a462018-01-12 10:21:40 +00003639 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003640 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003641 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003642 half4 a0 = vload4(0, src_addr_a);
3643 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003644
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003645 c0 += (half8)a0.s0 * b0;
3646 c1 += (half8)a0.s1 * b0;
3647 c2 += (half8)a0.s2 * b0;
3648 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003649 }
3650
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003651 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003652 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3653
Gian Marcoae2af742018-02-15 12:35:44 +00003654 // Compute dst address
3655 __global uchar *dst_addr = offset(&dst, 0, 0);
3656
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003657 uint4 zout = 0;
3658
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003659#if defined(REINTERPRET_OUTPUT_AS_3D)
3660 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003661 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003662 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003663 // | |
3664 // | plane0 |
3665 // | |
3666 // |__________________|
3667 // |******************|
3668 // | cross_plane_pad |
3669 // |******************|
3670 // | |
3671 // | plane1 |
3672 // | |
3673 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003674
3675 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003676 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3677 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003678
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003679 // Add offset due to the cross plane paddings
3680 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003681
3682 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3683 // multiply dst_stride_z by DEPTH_GEMM3D
3684 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003685#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003686 // Add offset for batched GEMM
3687 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003688#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3689
3690 // Multiply by the weight of matrix-matrix product and store the result
3691#if defined(ALPHA)
3692 SCALE_BLOCK(4, half, c, ALPHA);
3693#endif // defined(ALPHA)
3694
3695 // Add beta*bias
3696#if defined(BETA)
3697 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3698
3699#if defined(BROADCAST_BIAS)
3700 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
3701
3702 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3703
3704#ifndef UNIT_BETA
3705 SCALE_BLOCK(1, half, bias, BETA);
3706#endif // UNIT_BIAS
3707
3708 // c = c + bias[broadcasted]
3709 ADD_BLOCK_BROADCAST(4, c, bias0);
3710
3711#else // defined(BROADCAST_BIAS)
3712
3713 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3714 2) * src2_stride_z;
3715
3716 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3717
3718#ifndef UNIT_BETA
3719 SCALE_BLOCK(4, half, bias, BETA);
3720#endif // UNIT_BIAS
3721
3722 // c = c + bias
3723 ADD_BLOCK(4, c, bias);
3724
3725#endif // defined(BROADCAST_BIAS)
3726#endif // defined(BETA)
3727
3728#if defined(ACTIVATION_TYPE)
3729 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
3730#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003731
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003732 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003733 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
3734 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
3735 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
3736 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003737}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003738
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in 32-bit floating point variables.
 *
 * Inputs and output are F16. The partial products, the optional alpha scaling and the optional (beta * bias) addition are computed in F32.
 * The result is converted back to F16 before the optional activation and the store.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) is only read if beta is passed at compile time using -DBETA. The src2 kernel arguments (src2_ptr ... src2_stride_z) exist only in that case.
 *       -DBROADCAST_BIAS loads a single bias row and adds it to every row of the output block.
 *       -DUNIT_BETA skips the multiplication of the bias by beta.
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition, on the F16-converted result
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                        IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                        IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                        IMAGE_DECLARATION(dst),
                                                        uint src0_stride_z,
                                                        uint src1_stride_z,
#if defined(BETA)
                                                        uint src2_stride_z,
#endif //defined(BETA)
                                                        uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                        ,
                                                        uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                       )
{
    // Each work-item computes a 4x8 output block. MULT_TRANSPOSE1XW_WIDTH consecutive work-items along X share one
    // reshaped row of B, and MULT_INTERLEAVE4X4_HEIGHT consecutive work-items along Y share one reshaped row of A.
    // x/y select the shared row; offset_row_b/offset_row_a select this work-item's 8/4-element slice inside it.
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    // Byte offsets are kept unsigned: the strides are uint, and storing the result in a signed int would make
    // offsets >= 2^31 bytes wrap to negative values and address memory before the buffer.
    uint src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    uint src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global half *src_end_addr_b = src_addr_b + COLS_B;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (F32 to limit the precision loss of long F16 dot products)
    float8 c0 = 0.0f;
    float8 c1 = 0.0f;
    float8 c2 = 0.0f;
    float8 c3 = 0.0f;

    // Main loop: two K steps per iteration (two 4-element groups of A and two 8-element groups of B)
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one K step per iteration
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // The bias is stored in F16 but added to the F32 accumulators
    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // The bias is stored in F16 but added to the F32 accumulators
    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Convert the F32 accumulators back to the F16 output type
    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
3968
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003969/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003970 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003971 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003972 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3973 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3974 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3975 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003976 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003977 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3978 * The activation function is performed after the bias addition
3979 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003980 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3981 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3982 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3983 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3984 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003985 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
3986 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3987 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3988 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3989 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3990 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
3991 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
3992 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3993 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3994 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3995 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3996 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003997 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3998 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3999 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4000 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4001 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4002 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004003 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4004 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4005 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4006 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4007 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4008 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004009 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4010 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4011 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004012 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004013 */
4014__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
4015 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004016#if defined(BETA)
4017 IMAGE_DECLARATION(src2),
4018#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004019 IMAGE_DECLARATION(dst),
4020 uint src0_stride_z,
4021 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004022#if defined(BETA)
4023 uint src2_stride_z,
4024#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004025 uint dst_stride_z
4026#if defined(REINTERPRET_OUTPUT_AS_3D)
4027 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004028 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004029#endif // REINTERPRET_OUTPUT_AS_3D
4030 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004031{
4032 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
4033 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
4034 int z = get_global_id(2);
4035
4036 // Offset
4037 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
4038 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
4039
4040 // src_addr_a = address of matrix A
4041 // src_addr_b = address of matrix B
4042 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
4043 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
4044
4045#if defined(MATRIX_B_DEPTH)
4046 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4047 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
4048#else // defined(MATRIX_B_DEPTH)
4049 src1_addr_in_bytes += z * src1_stride_z;
4050#endif // defined(MATRIX_B_DEPTH)
4051
4052 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
4053 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
4054
4055 // Compute end row address for matrix B
4056 __global half *src_end_addr_b = src_addr_b + COLS_B;
4057
4058 src_addr_a += offset_row_a;
4059 src_addr_b += offset_row_b;
4060
4061 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004062 half8 c0 = 0.0f;
4063 half8 c1 = 0.0f;
4064 half8 c2 = 0.0f;
4065 half8 c3 = 0.0f;
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004066
4067#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
4068
4069 int i = 0;
4070 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
4071 {
4072#if MULT_INTERLEAVE4X4_HEIGHT == 1
4073 // Load values from matrix A (interleaved) and matrix B (transposed)
4074 half8 a0 = vload8(0, src_addr_a);
4075 half8 b0 = vload8(0, src_addr_b);
4076
4077 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
4078 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4079
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004080 c0 = fma((half8)a0.s0, b0, c0);
4081 c1 = fma((half8)a0.s1, b0, c1);
4082 c2 = fma((half8)a0.s2, b0, c2);
4083 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004084
4085 // Load values from matrix B (transposed)
4086 b0 = vload8(0, src_addr_b);
4087
4088 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4089
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004090 c0 = fma((half8)a0.s4, b0, c0);
4091 c1 = fma((half8)a0.s5, b0, c1);
4092 c2 = fma((half8)a0.s6, b0, c2);
4093 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004094
4095 // Load values from matrix A (interleaved) and matrix B (transposed)
4096 a0 = vload8(0, src_addr_a);
4097 b0 = vload8(0, src_addr_b);
4098
4099 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
4100 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4101
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004102 c0 = fma((half8)a0.s0, b0, c0);
4103 c1 = fma((half8)a0.s1, b0, c1);
4104 c2 = fma((half8)a0.s2, b0, c2);
4105 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004106
4107 // Load values from matrix B (transposed)
4108 b0 = vload8(0, src_addr_b);
4109
4110 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4111
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004112 c0 = fma((half8)a0.s4, b0, c0);
4113 c1 = fma((half8)a0.s5, b0, c1);
4114 c2 = fma((half8)a0.s6, b0, c2);
4115 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004116#else // MULT_INTERLEAVE4X4_HEIGHT == 1
4117 // Load values from matrix A (interleaved) and matrix B (transposed)
4118 half4 a0 = vload4(0, src_addr_a);
4119 half8 b0 = vload8(0, src_addr_b);
4120
4121 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4122 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4123
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004124 c0 = fma((half8)a0.s0, b0, c0);
4125 c1 = fma((half8)a0.s1, b0, c1);
4126 c2 = fma((half8)a0.s2, b0, c2);
4127 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004128
4129 // Load values from matrix A (interleaved) and matrix B (transposed)
4130 a0 = vload4(0, src_addr_a);
4131 b0 = vload8(0, src_addr_b);
4132
4133 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4134 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4135
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004136 c0 = fma((half8)a0.s0, b0, c0);
4137 c1 = fma((half8)a0.s1, b0, c1);
4138 c2 = fma((half8)a0.s2, b0, c2);
4139 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004140
4141 // Load values from matrix A (interleaved) and matrix B (transposed)
4142 a0 = vload4(0, src_addr_a);
4143 b0 = vload8(0, src_addr_b);
4144
4145 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4146 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4147
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004148 c0 = fma((half8)a0.s0, b0, c0);
4149 c1 = fma((half8)a0.s1, b0, c1);
4150 c2 = fma((half8)a0.s2, b0, c2);
4151 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004152
4153 // Load values from matrix A (interleaved) and matrix B (transposed)
4154 a0 = vload4(0, src_addr_a);
4155 b0 = vload8(0, src_addr_b);
4156
4157 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4158 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4159
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004160 c0 = fma((half8)a0.s0, b0, c0);
4161 c1 = fma((half8)a0.s1, b0, c1);
4162 c2 = fma((half8)a0.s2, b0, c2);
4163 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004164#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
4165 }
4166
4167 for(; i < (int)(COLS_MTX_B); ++i)
4168 {
4169 // Load values from matrix A (interleaved) and matrix B (transposed)
4170 half4 a0 = vload4(0, src_addr_a);
4171 half8 b0 = vload8(0, src_addr_b);
4172
4173 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4174 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4175
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004176 c0 = fma((half8)a0.s0, b0, c0);
4177 c1 = fma((half8)a0.s1, b0, c1);
4178 c2 = fma((half8)a0.s2, b0, c2);
4179 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004180 }
4181
4182 // Compute destination address
4183 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4184
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004185 // Compute dst address
4186 __global uchar *dst_addr = offset(&dst, 0, 0);
4187
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004188 uint4 zout = 0;
4189
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004190#if defined(REINTERPRET_OUTPUT_AS_3D)
4191 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004192 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004193 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004194 // | |
4195 // | plane0 |
4196 // | |
4197 // |__________________|
4198 // |******************|
4199 // | cross_plane_pad |
4200 // |******************|
4201 // | |
4202 // | plane1 |
4203 // | |
4204 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004205
4206 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004207 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4208 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004209
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004210 // Add offset due to the cross plane paddings
4211 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004212
4213 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4214 // multiply dst_stride_z by DEPTH_GEMM3D
4215 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004216#else // defined(REINTERPRET_OUTPUT_AS_3D)
4217 // Add offset for batched GEMM
4218 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004219#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4220
4221 // Multiply by the weight of matrix-matrix product and store the result
4222#if defined(ALPHA)
4223 SCALE_BLOCK(4, half, c, ALPHA);
4224#endif // defined(ALPHA)
4225
4226 // Add beta*bias
4227#if defined(BETA)
4228 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4229
4230#if defined(BROADCAST_BIAS)
4231 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4232
4233 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4234
4235#ifndef UNIT_BETA
4236 SCALE_BLOCK(1, half, bias, BETA);
4237#endif // UNIT_BIAS
4238
4239 // c = c + bias[broadcasted]
4240 ADD_BLOCK_BROADCAST(4, c, bias0);
4241
4242#else // defined(BROADCAST_BIAS)
4243 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4244 2) * src2_stride_z;
4245
4246 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4247
4248#ifndef UNIT_BETA
4249 SCALE_BLOCK(4, half, bias, BETA);
4250#endif // UNIT_BIAS
4251
4252 // c = c + bias
4253 ADD_BLOCK(4, c, bias);
4254
4255#endif // defined(BROADCAST_BIAS)
4256#endif // defined(BETA)
4257
4258#if defined(ACTIVATION_TYPE)
4259 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4260#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004261
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004262 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004263 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4264 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4265 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4266 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004267}
Georgios Pinitas84225582018-05-14 12:00:05 +01004268
// Undefine local defines
#undef COLS_MTX_B

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)

// The non-reshaped floating point GEMM kernels below need the number of columns of matrix A (COLS_A)
// and the size of the output block computed by each work-item (NUM_ELEMS_PROCESSED_PER_THREAD_X/_Y).
// All three are tested with defined() so the guard is a pure presence check; testing the bare
// NUM_ELEMS_PROCESSED_PER_THREAD_Y macro would rely on an undefined identifier evaluating to 0 and
// would fail to preprocess if the macro were defined with an empty value.
#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
// Vector type holding one row of the output block (NUM_ELEMS_PROCESSED_PER_THREAD_X elements of DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y must be in the range [1, 4]: only the accumulators acc0..acc3 exist and the per-row plane offsets (zin/zout) are held in a uint4.
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) and its arguments only exist if -DBETA is passed at compile time. The bias is scaled by BETA (skipped if -DUNIT_BETA is passed)
 *       and added to the result. If -DBROADCAST_BIAS is passed, a single bias row is loaded and broadcast to every row of the output block.
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements along Y (rows) for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements along Y (rows) for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif //defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    // First output column computed by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and matrix B: .s0 walks matrix A, .s1 walks matrix B (byte offsets)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A: first row of the output block handled by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B: first column of the output block handled by this work-item
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (zin becomes a per-row byte offset)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator row per output row handled by this work-item (at most 4)
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 2 columns of matrix A (and 2 rows of matrix B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Leftover loop: consume the remaining column of matrix A (when COLS_A is odd) one element at a time
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset for the output; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (zout becomes a per-row byte offset)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // Zero per-row offsets: the bias matrix is always read as a plain 2D matrix
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004589#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004590
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01004591/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004592 *
4593 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
4594 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
4595 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
4596 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
4597 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004598 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4599 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004600 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004601 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4602 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004603 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
4604 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004605 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4606 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4607 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4608 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4609 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004610 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004611 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4612 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4613 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4614 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4615 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4616 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4617 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4618 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4619 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4620 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4621 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004622 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
4623 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4624 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4625 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4626 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4627 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004628 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4629 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4630 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
4631 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4632 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
4633 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004634 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4635 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004636 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004637 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004638 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
4639 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004640 */
__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Each work-item computes a block of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows (at most 4 are handled:
    // acc0..acc3) by 4 columns of dst. The accumulators are float4, matrix B is read with vload4 and dst
    // is written with vstore4.
    // NOTE(review): 4 columns per work-item are hard-wired here, so NUM_ELEMS_PROCESSED_PER_THREAD_X is
    // presumably always 4 when this kernel is enqueued. Confirm against the host-side configuration.
    // No #error is added because the sibling _1000 kernel in this file is built with a different value.
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for matrix A: first row of the block handled by this work-item
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for matrix B: first column of the block handled by this work-item
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators: one float4 per output row of the block
    float4 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float4 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float4 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float4 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // A and B src indices get incremented at the same time.
    // Main loop: each iteration consumes 4 columns of A and 4 rows of B. Every k step is a rank-1 update
    // acc_r += a_r[k] * B[k, 0:4]. fma() on a broadcast float4 is evaluated lane by lane with a single
    // rounding, so it produces exactly the same values as four scalar fma() calls.
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A. As used here, LOAD_BLOCK declares the float4 rows a0, a1, ... consumed
        // below; the zin.s argument selects the per-row cross-plane offset (see gemm_helpers.h)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // k + 0: load row k of matrix B and multiply-accumulate with column 0 of the A block
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        acc0 = fma((float4)(a0.s0), b0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma((float4)(a1.s0), b0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma((float4)(a2.s0), b0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma((float4)(a3.s0), b0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // k + 1: load row k + 1 of matrix B and multiply-accumulate with column 1 of the A block
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        acc0 = fma((float4)(a0.s1), b0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma((float4)(a1.s1), b0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma((float4)(a2.s1), b0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma((float4)(a3.s1), b0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // k + 2: load row k + 2 of matrix B and multiply-accumulate with column 2 of the A block
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        acc0 = fma((float4)(a0.s2), b0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma((float4)(a1.s2), b0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma((float4)(a2.s2), b0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma((float4)(a3.s2), b0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // k + 3: load row k + 3 of matrix B and multiply-accumulate with column 3 of the A block
        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        acc0 = fma((float4)(a0.s3), b0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma((float4)(a1.s3), b0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma((float4)(a2.s3), b0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma((float4)(a3.s3), b0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(float);
    }

    // Left-over loop: handles the remaining COLS_A % 4 columns one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A, skipping the cross-plane padding of each row
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate (lane-wise, same rounding as per-lane scalar fma)
        acc0 = fma((float4)(a0), b0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma((float4)(a1), b0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma((float4)(a2), b0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma((float4)(a3), b0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += sizeof(float);
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row byte offset caused by the output cross-plane padding (stays 0 for a 2D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is shared by every output row: only the column offset of this work-item is applied
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
5042
5043/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5044 *
5045 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
5047 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5048 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
5049 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5050 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005051 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5052 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005053 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition.
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005056 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5057 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005058 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height (rows per plane) of the input/output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth (number of planes) of the input/output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
5062 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005063 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005064 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5065 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5066 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5067 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5068 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5069 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5070 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5071 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5072 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5073 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5074 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5076 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5077 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5078 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5079 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5080 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5082 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5083 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5084 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5085 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5086 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005087 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5088 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005089 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005090 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005091 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005093 */
5094__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
5095 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005096#if defined(BETA)
5097 IMAGE_DECLARATION(src2),
5098#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00005099 IMAGE_DECLARATION(dst),
5100 uint src0_stride_z,
5101 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005102#if defined(BETA)
5103 uint src2_stride_z,
5104#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005105 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005106#if defined(REINTERPRET_INPUT_AS_3D)
5107 ,
5108 uint src_cross_plane_pad
5109#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005110#if defined(REINTERPRET_OUTPUT_AS_3D)
5111 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005112 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005113#endif // REINTERPRET_OUTPUT_AS_3D
5114 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005115{
5116 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5117 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5118
5119 // Compute starting address for matrix A and Matrix B
5120 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5121
5122 // Update address for the matrix A
5123 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5124
5125 // Update address for the matrix B
5126 src_addr.s1 += idx * sizeof(float);
5127
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005128#if defined(REINTERPRET_INPUT_AS_3D)
5129 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5130 // in order to take into account the presence of possible cross plane paddings
5131 //
5132 // | |
5133 // | plane0 |
5134 // | |
5135 // |__________________|
5136 // |******************|
5137 // | cross_plane_pad |
5138 // |******************|
5139 // | |
5140 // | plane1 |
5141 // | |
5142 // |__________________|
5143
5144 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5145 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5146 zin = min(DEPTH_GEMM3D - 1, zin);
5147
5148 // Add offset due to the cross plane paddings
5149 zin *= (src_cross_plane_pad * src0_stride_y);
5150
5151 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5152 // multiply src0_stride_z by DEPTH_GEMM3D
5153 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5154
5155#else // defined(REINTERPRET_INPUT_AS_3D)
5156
Gian Marcoae2af742018-02-15 12:35:44 +00005157 // Add offset for batched GEMM
5158 src_addr.s0 += get_global_id(2) * src0_stride_z;
5159
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005160#endif // defined(REINTERPRET_INPUT_AS_3D)
5161
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005162#if defined(MATRIX_B_DEPTH)
5163 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5164 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5165#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005166 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005167#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005168
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005169 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005170 float2 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005171#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005172 float2 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005173#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5174#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005175 float2 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005176#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5177#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005178 float2 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005179#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5180
5181 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005182 int i = 0;
5183 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005184 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005185#if defined(REINTERPRET_INPUT_AS_3D)
5186 // Load values from matrix A
5187 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
5188#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005189 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005190 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005191#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005192
5193 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005194 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5195 src_addr.s1 += src1_stride_y;
5196 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5197 src_addr.s1 += src1_stride_y;
5198 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5199 src_addr.s1 += src1_stride_y;
5200 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5201 src_addr.s1 += src1_stride_y;
5202 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5203 src_addr.s1 += src1_stride_y;
5204 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5205 src_addr.s1 += src1_stride_y;
5206 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5207 src_addr.s1 += src1_stride_y;
5208 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5209 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005210
5211 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005212 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5213 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
5214 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
5215 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
5216 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
5217 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
5218 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
5219 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005220
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005221 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5222 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
5223 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
5224 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
5225 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
5226 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
5227 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
5228 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005229
5230#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005231#if defined(REINTERPRET_INPUT_AS_3D)
5232 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5233#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005234 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005235#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005236 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
5237 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
5238 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
5239 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
5240 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
5241 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
5242 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
5243 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005244
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005245 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
5246 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
5247 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
5248 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
5249 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
5250 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
5251 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
5252 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005253#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5254#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005255#if defined(REINTERPRET_INPUT_AS_3D)
5256 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5257#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005258 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005259#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005260 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
5261 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
5262 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
5263 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
5264 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
5265 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
5266 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
5267 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005268
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005269 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
5270 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
5271 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
5272 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
5273 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
5274 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
5275 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
5276 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005279#if defined(REINTERPRET_INPUT_AS_3D)
5280 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5281#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005282 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005283#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005284 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
5285 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
5286 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
5287 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
5288 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
5289 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
5290 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
5291 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005292
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005293 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
5294 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
5295 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
5296 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
5297 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
5298 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
5299 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
5300 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005301#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005302
5303 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005304 }
5305 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005306 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005307 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005308#if defined(REINTERPRET_INPUT_AS_3D)
5309 // Load values from matrix A
5310 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5311#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5312 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5313#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5314#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5315 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5316#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5317#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5318 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5319#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5320#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005321 // Load values from matrix A
5322 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5323#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5324 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5325#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5326#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5327 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5328#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5329#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5330 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5331#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005332#endif // defined(REINTERPRET_INPUT_AS_3D)
5333
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005334 // Load values from matrix B
5335 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005336 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005337
5338 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005339 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5340 acc0.s1 = fma(a0, b0.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005341#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005342 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5343 acc1.s1 = fma(a1, b0.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005344#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005346 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5347 acc2.s1 = fma(a2, b0.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005348#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5349#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005350 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5351 acc3.s1 = fma(a3, b0.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005352#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005353
5354 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005355 }
5356
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005357 int z = get_global_id(2);
5358
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005359 // Compute destination address
5360 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5361
Gian Marcoae2af742018-02-15 12:35:44 +00005362 // Compute dst address
5363 __global uchar *dst_addr = offset(&dst, 0, 0);
5364
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005365 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005366
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005367#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005368
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005369 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005370 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005371 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005372 // | |
5373 // | plane0 |
5374 // | |
5375 // |__________________|
5376 // |******************|
5377 // | cross_plane_pad |
5378 // |******************|
5379 // | |
5380 // | plane1 |
5381 // | |
5382 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00005383
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005384 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005385 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5386 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005387
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005388 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005389 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005390
5391 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5392 // multiply dst_stride_z by DEPTH_GEMM3D
5393 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005394#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005395 // Add offset for batched GEMM
5396 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005397#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5398
5399 // Multiply by the weight of matrix-matrix product and store the result
5400#if defined(ALPHA)
5401 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5402#endif // defined(ALPHA)
5403
5404 // Add beta*bias
5405#if defined(BETA)
5406 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5407
5408#if defined(BROADCAST_BIAS)
5409 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
5410
5411 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5412
5413#ifndef UNIT_BETA
5414 SCALE_BLOCK(1, float, bias, BETA);
5415#endif // UNIT_BIAS
5416
5417 // acc = acc + bias[broadcasted]
5418 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5419
5420#else // defined(BROADCAST_BIAS)
5421 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
5422 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5423
5424 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
5425
5426#ifndef UNIT_BETA
5427 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
5428#endif // UNIT_BIAS
5429
5430 // acc = acc + bias
5431 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5432
5433#endif // defined(BROADCAST_BIAS)
5434#endif // defined(BETA)
5435
5436#if defined(ACTIVATION_TYPE)
5437 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
5438#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005439
5440 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005441 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005442#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005443 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005444#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5445#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005446 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005447#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5448#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005449 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005450#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005451}
5452
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01005453#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01005454/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) when neither matrix has been reshaped
5455 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005456 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point (float) variables.
5457 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5458 * This kernel loads 8 elements of matrix B per row and offsets the bias by 8 elements per work-item, so it must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8.
5459 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5460 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005461 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5462 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005463 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005464 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
5465 * The activation function is applied after the bias addition
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005466 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5467 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
5468 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5469 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5470 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5471 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of matrix A (not reshaped)
5472 *
5473 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
5474 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5475 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
5476 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5477 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
5478 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5479 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5480 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5481 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
5482 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5483 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
5484 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005485 * @param[in] src2_ptr (Optional) Pointer to the bias matrix (only if defined BETA). Supported data type: same as @p src0_ptr
5486 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5487 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
5488 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5489 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem (in bytes)
5490 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005491 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5492 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5493 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
5494 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5495 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
5496 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
5497 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5498 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005499 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005500 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
5501 * @param[in] src_cross_plane_pad (Optional) Bottom padding of each input plane, in rows (multiplied by src0_stride_y) (only if defined REINTERPRET_INPUT_AS_3D)
5502 * @param[in] dst_cross_plane_pad (Optional) Bottom padding of each output plane, in rows (multiplied by dst_stride_y) (only if defined REINTERPRET_OUTPUT_AS_3D)
5503 */
5504__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
5505 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005506#if defined(BETA)
5507 IMAGE_DECLARATION(src2),
5508#endif // defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005509 IMAGE_DECLARATION(dst),
5510 uint src0_stride_z,
5511 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005512#if defined(BETA)
5513 uint src2_stride_z,
5514#endif //defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005515 uint dst_stride_z
5516#if defined(REINTERPRET_INPUT_AS_3D)
5517 ,
5518 uint src_cross_plane_pad
5519#endif // REINTERPRET_INPUT_AS_3D
5520#if defined(REINTERPRET_OUTPUT_AS_3D)
5521 ,
5522 uint dst_cross_plane_pad
5523#endif // REINTERPRET_OUTPUT_AS_3D
5524 )
5525{
5526 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5527
5528 // Compute starting address for matrix A and Matrix B
5529 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5530
5531 // Update address for the matrix A
5532 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5533
5534 // Update address for the matrix B
5535 src_addr.s1 += idx * sizeof(half);
5536
5537#if defined(REINTERPRET_INPUT_AS_3D)
5538 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5539 // in order to take into account the presence of possible cross plane paddings
5540 //
5541 // | |
5542 // | plane0 |
5543 // | |
5544 // |__________________|
5545 // |******************|
5546 // | cross_plane_pad |
5547 // |******************|
5548 // | |
5549 // | plane1 |
5550 // | |
5551 // |__________________|
5552
5553 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5554 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5555 zin = min(DEPTH_GEMM3D - 1, zin);
5556
5557 // Add offset due to the cross plane paddings
5558 zin *= (src_cross_plane_pad * src0_stride_y);
5559
5560 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5561 // multiply src0_stride_z by DEPTH_GEMM3D
5562 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5563
5564#else // defined(REINTERPRET_INPUT_AS_3D)
5565
5566 // Add offset for batched GEMM
5567 src_addr.s0 += get_global_id(2) * src0_stride_z;
5568
5569#endif // defined(REINTERPRET_INPUT_AS_3D)
5570
5571#if defined(MATRIX_B_DEPTH)
5572 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5573 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5574#else // defined(MATRIX_B_DEPTH)
5575 src_addr.s1 += get_global_id(2) * src1_stride_z;
5576#endif // defined(MATRIX_B_DEPTH)
5577
5578 float8 acc0 = 0.0h;
5579#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5580 float8 acc1 = 0.0h;
5581#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5582#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5583 float8 acc2 = 0.0h;
5584#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5585#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5586 float8 acc3 = 0.0h;
5587#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5588
5589 int i = 0;
5590 for(; i <= ((int)COLS_A - 4); i += 4)
5591 {
5592#if defined(REINTERPRET_INPUT_AS_3D)
5593 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01005594 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5595#else // defined(REINTERPRET_INPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005596 // Load values from matrix A
5597 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5598#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5599 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5600#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5601#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5602 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5603#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5604#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5605 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5606#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5607#endif // defined(REINTERPRET_INPUT_AS_3D)
5608
5609 // Load values from matrix B
5610 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5611 src_addr.s1 += src1_stride_y;
5612
5613 // Accumulate
5614 acc0 = fma(b0, (float8)a0.s0, acc0);
5615#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5616 acc1 = fma(b0, (float8)a1.s0, acc1);
5617#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5618#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5619 acc2 = fma(b0, (float8)a2.s0, acc2);
5620#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5621#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5622 acc3 = fma(b0, (float8)a3.s0, acc3);
5623#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5624
5625 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5626 src_addr.s1 += src1_stride_y;
5627 acc0 = fma(b0, (float8)a0.s1, acc0);
5628#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5629 acc1 = fma(b0, (float8)a1.s1, acc1);
5630#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5631#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5632 acc2 = fma(b0, (float8)a2.s1, acc2);
5633#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5634#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5635 acc3 = fma(b0, (float8)a3.s1, acc3);
5636#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5637
5638 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5639 src_addr.s1 += src1_stride_y;
5640 acc0 = fma(b0, (float8)a0.s2, acc0);
5641#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5642 acc1 = fma(b0, (float8)a1.s2, acc1);
5643#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5644#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5645 acc2 = fma(b0, (float8)a2.s2, acc2);
5646#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5647#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5648 acc3 = fma(b0, (float8)a3.s2, acc3);
5649#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5650
5651 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5652 src_addr.s1 += src1_stride_y;
5653 acc0 = fma(b0, (float8)a0.s3, acc0);
5654#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5655 acc1 = fma(b0, (float8)a1.s3, acc1);
5656#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5657#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5658 acc2 = fma(b0, (float8)a2.s3, acc2);
5659#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5660#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5661 acc3 = fma(b0, (float8)a3.s3, acc3);
5662#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5663
5664 src_addr.s0 += 4 * sizeof(half);
5665 }
5666
5667 for(; i < (int)COLS_A; ++i)
5668 {
5669#if defined(REINTERPRET_INPUT_AS_3D)
5670 // Load values from matrix A
5671 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5672#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5673 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5674#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5675#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5676 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5677#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5678#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5679 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5680#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5681#else // defined(REINTERPRET_INPUT_AS_3D)
5682 // Load values from matrix A
5683 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5684#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5685 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5686#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5687#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5688 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5689#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5690#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5691 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5693#endif // defined(REINTERPRET_INPUT_AS_3D)
5694
5695 // Load values from matrix B
5696 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
5697
5698 src_addr += (int2)(sizeof(half), src1_stride_y);
5699
5700 // Accumulate
5701 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
5702#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5703 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
5704#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5705#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5706 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
5707#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5708#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5709 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
5710#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5711 }
5712
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005713 int z = get_global_id(2);
5714
5715 // Compute destination address
5716 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5717
5718 // Compute dst address
5719 __global uchar *dst_addr = offset(&dst, 0, 0);
5720
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005721 uint4 zout = 0;
5722
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005723#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005724
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005725 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
5726 // in order to take into account the presence of possible cross plane paddings
5727 //
5728 // | |
5729 // | plane0 |
5730 // | |
5731 // |__________________|
5732 // |******************|
5733 // | cross_plane_pad |
5734 // |******************|
5735 // | |
5736 // | plane1 |
5737 // | |
5738 // |__________________|
5739
5740 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005741 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5742 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005743
5744 // Add offset due to the cross plane paddings
5745 zout *= (dst_cross_plane_pad * dst_stride_y);
5746
5747 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5748 // multiply dst_stride_z by DEPTH_GEMM3D
5749 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005750#else // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005751 // Add offset for batched GEMM
5752 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005753#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005754
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005755 // Multiply by the weight of matrix-matrix product and store the result
5756#if defined(ALPHA)
5757 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5758#endif // defined(ALPHA)
5759
5760#if defined(BETA)
5761 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5762
5763#if defined(BROADCAST_BIAS)
5764 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
5765
5766 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5767
5768 float8 bias_f0 = convert_float8(bias0);
5769
5770#ifndef UNIT_BETA
5771 SCALE_BLOCK(1, float, bias_f, BETA);
5772#endif // UNIT_BIAS
5773
5774 // acc = acc + bias[broadcasted]
5775 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);
5776
5777#else // defined(BROADCAST_BIAS)
5778 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
5779 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5780
5781 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
5782
5783 float8 bias_f0 = convert_float8(bias0);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005784#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005785 float8 bias_f1 = convert_float8(bias1);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005786#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5787#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005788 float8 bias_f2 = convert_float8(bias2);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005789#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5790#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005791 float8 bias_f3 = convert_float8(bias3);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005792#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005793
5794#ifndef UNIT_BETA
5795 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
5796#endif // UNIT_BIAS
5797
5798 // acc = acc + bias
5799 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);
5800
5801#endif // defined(BROADCAST_BIAS)
5802#endif // defined(BETA)
5803
5804 half8 acc_h0 = convert_half8(acc0);
5805#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5806 half8 acc_h1 = convert_half8(acc1);
5807#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5808#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5809 half8 acc_h2 = convert_half8(acc2);
5810#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5811#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5812 half8 acc_h3 = convert_half8(acc3);
5813#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5814
5815#if defined(ACTIVATION_TYPE)
5816 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
5817#endif // defined(ACTIVATION_TYPE)
5818
5819 // Store the output block
5820 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00005821}
5822
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half), accumulates in half precision and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Each work-item loads, accumulates and stores 8 output columns (half8), so this kernel must be compiled with -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8.
 *       NUM_ELEMS_PROCESSED_PER_THREAD_Y selects how many output rows (1 to 4) each work-item computes.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note The optional value of scalar beta is passed at compile time using -DBETA=beta. When BETA is defined, the bias matrix (src2) multiplied by beta is added to the result.
 *       With -DBROADCAST_BIAS the bias is a single row that is broadcast to every output row. With -DUNIT_BETA the multiplication of the bias by beta is skipped.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings, in number of rows, for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings, in number of rows, for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector argument goes first to match the built-in min(gentype, sgentype) overload.
    zin = min(zin, (uint)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row processed by this work-item
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of A (and 4 rows of B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    // Leftover loop: remaining COLS_A % 4 columns, one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance one column of A and one row of B
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row cross-plane byte offsets used by STORE_BLOCK (all zero for a 2D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. The vector argument goes first to match the built-in min(gentype, sgentype) overload.
    zout = min(zout, (uint)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // The bias is always read without cross-plane offsets
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Single bias row: only the X offset applies
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006168#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006169
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006170#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006171
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006172#if defined(BETA)
/** In-place weighted matrix addition for F32: dst = dst + BETA * src.
 *
 * Each work-item reads 4 consecutive values from the destination, which holds the result of the
 * matrix product A x B (presumably already multiplied by alpha; confirm against the host side).
 * It adds the 4 matching values of matrix C (src) multiplied by BETA and writes the sum back into
 * the destination.
 *
 * @note The value of beta needs to be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Work-item views of matrix C (read only) and of the A x B result (read and overwritten)
    Tensor3D mat_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mat_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *ab_elems = (__global float *)mat_ab.ptr;
    __global float *c_elems  = (__global float *)mat_c.ptr;

    const float4 ab_vals = vload4(0, ab_elems);
    const float4 c_vals  = vload4(0, c_elems);

    // alpha * (A x B) + beta * C, written back in place
    vstore4(ab_vals + (float4)BETA * c_vals, 0, ab_elems);
}
6213
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** In-place weighted matrix addition for F16: dst = dst + BETA * src.
 *
 * Each work-item reads 8 consecutive values from the destination, which holds the result of the
 * matrix product A x B (presumably already multiplied by alpha; confirm against the host side).
 * It adds the 8 matching values of matrix C (src) multiplied by BETA and writes the sum back into
 * the destination.
 *
 * @note The value of beta needs to be passed at compile time using -DBETA
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Work-item views of matrix C (read only) and of the A x B result (read and overwritten)
    Tensor3D mat_c  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D mat_ab = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *ab_elems = (__global half *)mat_ab.ptr;
    __global half *c_elems  = (__global half *)mat_c.ptr;

    const half8 ab_vals = vload8(0, ab_elems);
    const half8 c_vals  = vload8(0, c_elems);

    // alpha * (A x B) + beta * C, written back in place
    vstore8(ab_vals + (half8)BETA * c_vals, 0, ab_elems);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006256#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006257
#if defined(WIDTH_VECTOR_A)
/** Multiplies each row of A (src0), taken as a vector, by its own matrix B (src1), as used by the locally connected layer.
 *
 * Work-item (x, y) uses row y of A and Z-slice y of B. It accumulates the 4 consecutive F32
 * outputs that start at column 4 * x of that slice, then stores them at its dst position
 * (as computed by CONVERT_TO_IMAGE_STRUCT).
 *
 * @note The number of elements in a row of A must be passed at compile time using -DWIDTH_VECTOR_A
 * @note Neither A nor B may be reshaped
 * @note Each work-item reads 4 floats per row of B and writes 4 floats to dst with no bounds check.
 *       Presumably the host pads these tensors accordingly — TODO confirm against the host-side configuration.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix A. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix A in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix A in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix A
 * @param[in]  src1_ptr                           Pointer to the source matrix B. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix B in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix B in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix B
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    const int first_col = get_global_id(0) * 4;
    const int row       = get_global_id(1);

    // Byte offsets: a_off walks along row `row` of A, b_off walks down the rows of slice `row` of B
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += first_col * sizeof(float);

    const int a_end = a_off + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume two elements of A (and two rows of B) per iteration
    while(a_off <= a_end - 2 * (int)sizeof(float))
    {
        const float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_off));
        const float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_off));
        const float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y));

        acc += b_row0 * (float4)a_pair.s0;
        acc += b_row1 * (float4)a_pair.s1;

        a_off += 2 * (int)sizeof(float);
        b_off += (int)(2 * src1_stride_y);
    }

    // Tail: at most one element of A remains here (only when WIDTH_VECTOR_A is odd)
    while(a_off < a_end)
    {
        const float  a_val = ((__global float *)(src0_ptr + a_off))[0];
        const float4 b_row = vload4(0, (__global float *)(src1_ptr + b_off));

        acc += b_row * (float4)a_val;

        a_off += (int)sizeof(float);
        b_off += (int)src1_stride_y;
    }

    // Write the 4 accumulated outputs at this work-item's destination position
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    vstore4(acc, 0, (__global float *)dst.ptr);
}
#endif // defined(WIDTH_VECTOR_A)
6325
/** This kernel accumulates each row with the biases vector, in place.
 *
 * Each work-item loads VECTOR_SIZE elements at its position in @p accum and VECTOR_SIZE
 * elements at its position in @p biases. It adds them with the plain `+` operator, so integer
 * types are not saturated, and stores the sum back to the same location in @p accum.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *accum_addr  = (__global DATA_TYPE *)accum.ptr;
    __global DATA_TYPE *biases_addr = (__global DATA_TYPE *)biases.ptr;

    // Read VECTOR_SIZE biases and VECTOR_SIZE accumulator values, add, and overwrite the accumulator
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    sum = VLOAD(VECTOR_SIZE)(0, biases_addr) + VLOAD(VECTOR_SIZE)(0, accum_addr);

    VSTORE(VECTOR_SIZE)
    (sum, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)